|
21 | 21 | #endif |
22 | 22 |
|
23 | 23 | #include <cassert> |
| 24 | +#include <cstddef> |
| 25 | +#include <cstdint> |
24 | 26 |
|
25 | 27 | namespace DynLibUtils { |
26 | 28 |
|
27 | | -class MemUnprotector |
| 29 | +using ProtectFlags_t = unsigned long; |
| 30 | + |
| 31 | +class VirtualUnprotector |
28 | 32 | { |
29 | 33 | public: |
30 | | - MemUnprotector(CMemory pAddress, size_t nLength = sizeof(void*)) |
| 34 | + VirtualUnprotector(void *pTarget, std::size_t nLength = sizeof(void*)) |
31 | 35 | { |
32 | 36 | #if _WIN32 |
33 | | - VirtualProtect(pAddress, nLength, PAGE_EXECUTE_READWRITE, &m_nOldProtection); |
34 | | - |
35 | | - m_pAddress = pAddress; |
36 | 37 | m_nLength = nLength; |
| 38 | + m_pTarget = pTarget; |
| 39 | + |
| 40 | + assert(VirtualProtect(pTarget, nLength, PAGE_EXECUTE_READWRITE, &m_nOldProtect)); |
37 | 41 | #else |
38 | | - long nPageSize = sysconf(_SC_PAGESIZE); |
39 | | - CMemory pPageStart = pAddress & ~(nPageSize - 1l); |
40 | | - CMemory pPageEnd = (pAddress + nLength + nPageSize - 1l) & ~(nPageSize - 1l); |
41 | | - size_t nAligned = pPageEnd - pPageStart; |
| 42 | + long pageSize = sysconf(_SC_PAGESIZE); |
| 43 | + |
| 44 | + assert(pageSize >= 0); |
42 | 45 |
|
43 | | - mprotect(pPageStart, nAligned, PROT_READ | PROT_WRITE | PROT_EXEC); |
| 46 | + auto nPageSize = static_cast<std::uintptr_t>(pageSize); |
44 | 47 |
|
45 | | - m_pAddress = pPageStart; |
| 48 | + auto pAddress = reinterpret_cast<std::uintptr_t>(pTarget); |
| 49 | + CMemory pPageStart = pAddress & ~(nPageSize - 1l); |
| 50 | + std::uintptr_t pPageEnd = (pAddress + nLength + nPageSize - 1l) & ~(nPageSize - 1l); |
| 51 | + auto nAligned = static_cast<std::size_t>(pPageEnd - pPageStart); |
| 52 | + |
| 53 | + m_nOldProtect = PROT_READ; |
46 | 54 | m_nLength = nAligned; |
47 | | - m_nOldProtection = PROT_READ | PROT_WRITE; //TODO: Need to parse /proc/self/maps |
| 55 | + m_pTarget = pPageStart; |
| 56 | + |
| 57 | + assert(!mprotect(pPageStart, nAligned, PROT_READ | PROT_WRITE)); |
48 | 58 | #endif |
49 | 59 | } |
50 | 60 |
|
51 | | - ~MemUnprotector() |
| 61 | + ~VirtualUnprotector() |
52 | 62 | { |
53 | 63 | #if _WIN32 |
54 | 64 | DWORD origProtect; |
55 | | - VirtualProtect(m_pAddress, m_nLength, m_nOldProtection, &origProtect); |
| 65 | + assert(VirtualProtect(m_pTarget, m_nLength, m_nOldProtect, &origProtect)); |
56 | 66 | #else |
57 | | - mprotect(m_pAddress, m_nLength, m_nOldProtection); |
| 67 | + assert(!mprotect(m_pTarget, m_nLength, m_nOldProtect)); |
58 | 68 | #endif |
59 | 69 | } |
60 | 70 |
|
61 | 71 | private: |
62 | | - size_t m_nLength; |
63 | | - unsigned long m_nOldProtection; |
64 | | - |
65 | | - CMemory m_pAddress; |
66 | | -}; // class MemUnprotector |
| 72 | + ProtectFlags_t m_nOldProtect; |
| 73 | + std::size_t m_nLength; |
| 74 | + CMemory m_pTarget; |
| 75 | +}; // class VirtualUnprotector |
67 | 76 |
|
68 | 77 | template <typename T, typename C, typename ...Args> |
69 | 78 | class VTHook |
@@ -102,14 +111,14 @@ class VTHook |
102 | 111 | protected: |
103 | 112 | void HookImpl(T(*pfnTarget)(C*, Args...)) |
104 | 113 | { |
105 | | - MemUnprotector unprotector(m_vmpFn); |
| 114 | + VirtualUnprotector unprotect(m_vmpFn); |
106 | 115 |
|
107 | 116 | *m_vmpFn.RCast<T(**)(C*, Args...)>() = pfnTarget; |
108 | 117 | } |
109 | 118 |
|
110 | 119 | void UnhookImpl() |
111 | 120 | { |
112 | | - MemUnprotector unprotector(m_vmpFn); |
| 121 | + VirtualUnprotector unprotect(m_vmpFn); |
113 | 122 |
|
114 | 123 | *m_vmpFn.RCast<void **>() = m_pOriginalFn; |
115 | 124 | } |
|
0 commit comments