Skip to content

Commit 82dd241

Browse files
committed
Add GetVirtualMethodIndex method
& update `CVirtualTable` class
1 parent f720714 commit 82dd241

2 files changed

Lines changed: 74 additions & 17 deletions

File tree

include/dynlibutils/virtual.hpp

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,70 @@
77

88
#pragma once
99

10+
#include "memaddr.hpp"
11+
1012
#include <cstddef>
13+
#include <cstdint>
14+
#include <bit>
15+
#include <type_traits>
16+
17+
#define DYNLIB_INVALID_VMETHOD_INDEX -1
1118

1219
namespace DynLibUtils {
1320

14-
class CVirtualTable
21+
template<auto METHOD, class T>
22+
inline std::ptrdiff_t GetVirtualMethodIndex(T *pClass) noexcept
23+
{
24+
static_assert(std::is_member_function_pointer_v<decltype(METHOD)>, "Templated method must be a pointer-to-member-function");
25+
26+
#if defined(_MSC_VER)
27+
// --- MSVC ABI: runtime scan vtable (no constexpr!) ---
28+
struct MSVCPMF { void* ptr; std::ptrdiff_t adj; };
29+
union { decltype(METHOD) m; MSVCPMF pmf; } u{ METHOD };
30+
void* target = u.pmf.ptr;
31+
32+
void** vtbl = *reinterpret_cast<void***>(pClass);
33+
34+
constexpr std::size_t header = 2; // [RTTI][offset-to-top]
35+
36+
for (std::size_t n = header; ; ++n)
37+
if (vtbl[n] == target)
38+
return n - header;
39+
40+
#elif defined(__GNUG__) || defined(__clang__)
41+
// --- Itanium C++ ABI: PMF.ptr = 1 + offset_in_bytes for virtual ones ---
42+
struct ItaniumPMF { void* ptr; std::ptrdiff_t adj; };
43+
union { decltype(METHOD) m; ItaniumPMF pmf; } u{ METHOD };
44+
45+
constexpr auto raw = reinterpret_cast<std::uintptr_t>(u.pmf.ptr);
46+
47+
static_assert((raw & 1u) != 0, "Not a virtual member function");
48+
49+
return (raw - 1u) / sizeof(void*);
50+
#else
51+
static_assert(false, "Unsupported compiler");
52+
#endif
53+
54+
return DYNLIB_INVALID_VMETHOD_INDEX;
55+
}
56+
57+
class CVirtualTable : public CMemoryView<void *>
1558
{
1659
public: // Types.
60+
using CBase = CMemoryView<void *>;
1761
using CThis = CVirtualTable;
1862

1963
public: // Constructors.
20-
CVirtualTable() : m_pVTFs(nullptr) {}
21-
template<class T> CVirtualTable(T *pClass) : m_pVTFs(*reinterpret_cast<void ***>(pClass)) {}
64+
CVirtualTable() : CBase(nullptr) {}
65+
template<class T> CVirtualTable(T *pClass) : CBase(*reinterpret_cast<void **>(pClass)) {}
2266

2367
public: // Getters.
24-
template<typename R> R &GetMethod(std::ptrdiff_t nIndex) { return reinterpret_cast<R &>(m_pVTFs[nIndex]); }
25-
template<typename R> R GetMethod(std::ptrdiff_t nIndex) const { return reinterpret_cast<R>(m_pVTFs[nIndex]); }
68+
template<typename R> R &GetMethod(std::ptrdiff_t nIndex) { return CBase::Offset(nIndex).RCast<R &>(); }
69+
template<typename R> R GetMethod(std::ptrdiff_t nIndex) const { return CBase::Offset(nIndex).RCast<R>(); }
2670

2771
public: // Callers.
2872
template<typename R, typename... Args> R CallMethod(std::ptrdiff_t nIndex, Args... args) { return GetMethod<R (*)(void *, Args...)>(nIndex)(this, args...); }
2973
template<typename R, typename... Args> R CallMethod(std::ptrdiff_t nIndex, Args... args) const { return const_cast<CThis *>(this)->CallMethod(nIndex, args...); }
30-
31-
void **m_pVTFs;
3274
}; // class CVirtualTable
3375

3476
class VirtualTable final : public CVirtualTable

include/dynlibutils/vthook.hpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include "memaddr.hpp"
12+
#include "virtual.hpp"
1213

1314
#if _WIN32
1415
# define WIN32_LEAN_AND_MEAN
@@ -71,33 +72,47 @@ class VTHook
7172
VTHook() = default;
7273
~VTHook() { Unhook(); }
7374

75+
void Clear()
76+
{
77+
m_vmpFn = nullptr;
78+
m_pOriginalFn = nullptr;
79+
}
80+
7481
bool IsHooked() const { return m_pOriginalFn.IsValid(); }
82+
T Call(C* pThis, Args... args) { return m_pOriginalFn.RCast<T(*)(C*, Args...)>()(pThis, args...); }
7583

76-
void Hook(CMemory pVTable, int index, T(*pFn)(C*, Args...))
84+
void Hook(CVirtualTable pVTable, std::ptrdiff_t nIndex, T(*pFn)(C*, Args...))
7785
{
7886
assert(!IsHooked());
7987

80-
m_vmpFn = pVTable.Offset(index * sizeof(void*));
88+
m_vmpFn = pVTable.Offset(nIndex);
8189
m_pOriginalFn = m_vmpFn.Deref();
8290

83-
MemUnprotector unprotector(m_vmpFn);
84-
85-
*m_vmpFn.RCast<T(**)(C*, Args...)>() = pFn;
91+
HookImpl(pFn);
8692
}
8793

8894
void Unhook()
8995
{
9096
assert(IsHooked());
9197

92-
MemUnprotector unprotector(m_vmpFn);
98+
UnhookImpl();
99+
Clear();
100+
}
93101

94-
*m_vmpFn.RCast<void**>() = m_pOriginalFn;
102+
protected:
103+
void HookImpl(T(*pfnTarget)(C*, Args...))
104+
{
105+
MemUnprotector unprotector(m_vmpFn);
95106

96-
m_vmpFn = nullptr;
97-
m_pOriginalFn = nullptr;
107+
*m_vmpFn.RCast<T(**)(C*, Args...)>() = pfnTarget;
98108
}
99109

100-
T Call(C* pThis, Args... args) { return m_pOriginalFn.RCast<T(*)(C*, Args...)>()(pThis, args...); }
110+
void UnhookImpl()
111+
{
112+
MemUnprotector unprotector(m_vmpFn);
113+
114+
*m_vmpFn.RCast<void **>() = m_pOriginalFn;
115+
}
101116

102117
private:
103118
CMemory m_vmpFn;

0 commit comments

Comments
 (0)