Skip to content

Commit 3109e01

Browse files
committed
ARM64EC: Optimize GPR and MM state setting
I think the code improvement speaks for itself here.
1 parent c6d2ce0 commit 3109e01

1 file changed

Lines changed: 49 additions & 20 deletions

File tree

Source/Windows/ARM64EC/Module.cpp

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ desc: Implements the ARM64EC BT module API using FEXCore
4646
#include "BTInterface.h"
4747
#include "Windows/Common/SHMStats.h"
4848

49+
#include <arm_neon.h>
4950
#include <cstdint>
5051
#include <cstdio>
5152
#include <type_traits>
@@ -311,32 +312,46 @@ static bool HandleUnalignedAccess(const ThreadCPUArea CPUArea, ARM64_NT_CONTEXT&
311312
return Result.has_value();
312313
}
313314

314-
static void LoadStateFromECContext(FEXCore::Core::InternalThreadState* Thread, CONTEXT& Context) {
315+
static void LoadStateFromECContext(FEXCore::Core::InternalThreadState* Thread, const CONTEXT& Context) {
315316
auto& State = Thread->CurrentFrame->State;
316317

317318
if ((Context.ContextFlags & CONTEXT_INTEGER) == CONTEXT_INTEGER) {
319+
// Ensure ordering.
320+
static_assert(((offsetof(CONTEXT, Rax) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_RAX);
321+
static_assert(((offsetof(CONTEXT, Rcx) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_RCX);
322+
static_assert(((offsetof(CONTEXT, Rdx) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_RDX);
323+
static_assert(((offsetof(CONTEXT, Rbx) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_RBX);
324+
static_assert(((offsetof(CONTEXT, Rsp) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_RSP);
325+
static_assert(((offsetof(CONTEXT, Rbp) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_RBP);
326+
static_assert(((offsetof(CONTEXT, Rsi) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_RSI);
327+
static_assert(((offsetof(CONTEXT, Rdi) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_RDI);
328+
static_assert(((offsetof(CONTEXT, R8) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_R8);
329+
static_assert(((offsetof(CONTEXT, R9) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_R9);
330+
static_assert(((offsetof(CONTEXT, R10) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_R10);
331+
static_assert(((offsetof(CONTEXT, R11) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_R11);
332+
static_assert(((offsetof(CONTEXT, R12) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_R12);
333+
static_assert(((offsetof(CONTEXT, R13) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_R13);
334+
static_assert(((offsetof(CONTEXT, R14) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_R14);
335+
static_assert(((offsetof(CONTEXT, R15) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_R15);
336+
318337
// General register state
319-
State.gregs[FEXCore::X86State::REG_RAX] = Context.Rax;
320-
State.gregs[FEXCore::X86State::REG_RCX] = Context.Rcx;
321-
State.gregs[FEXCore::X86State::REG_RDX] = Context.Rdx;
322-
State.gregs[FEXCore::X86State::REG_RBX] = Context.Rbx;
323-
324-
State.gregs[FEXCore::X86State::REG_RSI] = Context.Rsi;
325-
State.gregs[FEXCore::X86State::REG_RDI] = Context.Rdi;
326-
State.gregs[FEXCore::X86State::REG_R8] = Context.R8;
327-
State.gregs[FEXCore::X86State::REG_R9] = Context.R9;
328-
State.gregs[FEXCore::X86State::REG_R10] = Context.R10;
329-
State.gregs[FEXCore::X86State::REG_R11] = Context.R11;
330-
State.gregs[FEXCore::X86State::REG_R12] = Context.R12;
331-
State.gregs[FEXCore::X86State::REG_R13] = Context.R13;
332-
State.gregs[FEXCore::X86State::REG_R14] = Context.R14;
333-
State.gregs[FEXCore::X86State::REG_R15] = Context.R15;
338+
auto Src = reinterpret_cast<const uint64_t*>(&Context.Rax);
339+
auto Dst = reinterpret_cast<uint64_t*>(&State.gregs[FEXCore::X86State::REG_RAX]);
340+
341+
asm volatile(R"(
342+
ld1 {v0.2d, v1.2d, v2.2d, v3.2d}, [%[Src]], #64;
343+
st1 {v0.2d, v1.2d, v2.2d, v3.2d}, [%[Dst]], #64;
344+
ld1 {v0.2d, v1.2d, v2.2d, v3.2d}, [%[Src]], #64;
345+
st1 {v0.2d, v1.2d, v2.2d, v3.2d}, [%[Dst]], #64;
346+
)"
347+
: [Src] "+r"(Src), [Dst] "+r"(Dst)::"memory", "v0", "v1", "v2", "v3");
334348
}
335349

336350
if ((Context.ContextFlags & CONTEXT_CONTROL) == CONTEXT_CONTROL) {
337351
State.rip = Context.Rip;
338-
State.gregs[FEXCore::X86State::REG_RSP] = Context.Rsp;
339-
State.gregs[FEXCore::X86State::REG_RBP] = Context.Rbp;
352+
static_assert(((offsetof(CONTEXT, Rsp) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_RSP);
353+
static_assert(((offsetof(CONTEXT, Rbp) - offsetof(CONTEXT, Rax)) / sizeof(uint64_t)) == FEXCore::X86State::REG_RBP);
354+
memcpy(&State.gregs[FEXCore::X86State::REG_RSP], &Context.Rsp, sizeof(uint64_t) * 2);
340355
CTX->SetFlagsFromCompactedEFLAGS(Thread, Context.EFlags);
341356
}
342357

@@ -364,13 +379,27 @@ static void LoadStateFromECContext(FEXCore::Core::InternalThreadState* Thread, C
364379
if ((Context.ContextFlags & CONTEXT_FLOATING_POINT) == CONTEXT_FLOATING_POINT) {
365380
// Floating-point register state
366381
if ((Context.ContextFlags & CONTEXT_XSTATE) == CONTEXT_XSTATE) {
367-
const auto* Ymm = RtlLocateExtendedFeature(reinterpret_cast<CONTEXT_EX*>(&Context + 1), XSTATE_AVX, nullptr);
382+
auto Ymm = RtlLocateExtendedFeature(const_cast<CONTEXT_EX*>(reinterpret_cast<const CONTEXT_EX*>(&Context + 1)), XSTATE_AVX, nullptr);
368383
CTX->SetXMMRegistersFromState(Thread, reinterpret_cast<const __uint128_t*>(Context.FltSave.XmmRegisters),
369384
reinterpret_cast<const __uint128_t*>(Ymm));
370385
} else {
371386
CTX->SetXMMRegistersFromState(Thread, reinterpret_cast<const __uint128_t*>(Context.FltSave.XmmRegisters), nullptr);
372387
}
373-
memcpy(State.mm, Context.FltSave.FloatRegisters, sizeof(State.mm));
388+
389+
// Sanity check to make sure padding is correct.
390+
static_assert(sizeof(State.mm[0]) == 16);
391+
392+
// X87 registers
393+
auto Src = reinterpret_cast<const uint64_t*>(Context.FltSave.FloatRegisters);
394+
auto Dst = reinterpret_cast<uint64_t*>(State.mm);
395+
396+
asm volatile(R"(
397+
ld1 {v0.2d, v1.2d, v2.2d, v3.2d}, [%[Src]], #64;
398+
st1 {v0.2d, v1.2d, v2.2d, v3.2d}, [%[Dst]], #64;
399+
ld1 {v0.2d, v1.2d, v2.2d, v3.2d}, [%[Src]], #64;
400+
st1 {v0.2d, v1.2d, v2.2d, v3.2d}, [%[Dst]], #64;
401+
)"
402+
: [Src] "+r"(Src), [Dst] "+r"(Dst)::"memory", "v0", "v1", "v2", "v3");
374403

375404
State.FCW = Context.FltSave.ControlWord;
376405
State.flags[FEXCore::X86State::X87FLAG_IE_LOC] = Context.FltSave.StatusWord & 1;

0 commit comments

Comments
 (0)