Skip to content

Commit b98cb14

Browse files
committed
added exe loading support
1 parent ac32ef7 commit b98cb14

File tree

5 files changed

+388
-26
lines changed

5 files changed

+388
-26
lines changed

CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ if (YAIL_BUILD_EXAMPLES)
4040
set_target_properties(test_dll PROPERTIES CXX_STANDARD 23)
4141
target_link_libraries(test_dll PRIVATE yail dbghelp winmm delayimp)
4242
target_link_options(test_dll PRIVATE /DELAYLOAD:dbghelp.dll /DELAYLOAD:winmm.dll)
43+
44+
add_executable(exe_loader examples/exe_loader.cpp)
45+
set_target_properties(exe_loader PROPERTIES CXX_STANDARD 23)
46+
target_link_libraries(exe_loader PRIVATE yail)
47+
48+
add_executable(test_exe examples/test_exe.cpp)
49+
set_target_properties(test_exe PROPERTIES CXX_STANDARD 23)
50+
target_link_options(test_exe PRIVATE /DYNAMICBASE)
4351
endif()
4452

4553

examples/exe_loader.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#include "yail/yail.hpp"
2+
#include <Windows.h>
3+
#include <cstdio>
4+
#include <string>
5+
#include <print>
6+
7+
int main(int argc, char* argv[])
8+
{
9+
std::string exePath = "test_exe.exe";
10+
if (argc > 1)
11+
exePath = argv[1];
12+
13+
printf("[exe_loader] Manual-mapping: %s\n\n", exePath.c_str());
14+
15+
auto result = yail::manual_map_injection_from_file(exePath, GetCurrentProcessId());
16+
17+
if (!result)
18+
{
19+
std::println("[exe_loader] FAILED: {}", result.error());
20+
return 1;
21+
}
22+
23+
std::println("\n[exe_loader] Success - image loaded at 0x{:x}\n", result.value());
24+
return 0;
25+
}

examples/test_exe.cpp

Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
#include <Windows.h>
2+
#include <cstdio>
3+
#include <cmath>
4+
#include <cstring>
5+
#include <string>
6+
#include <vector>
7+
#include <thread>
8+
#include <mutex>
9+
#include <atomic>
10+
#include <functional>
11+
#include <stdexcept>
12+
#include <memory>
13+
#include <algorithm>
14+
// =========================================================================
15+
// Globals & helpers
16+
// =========================================================================
17+
static int g_passed = 0;
18+
static int g_total = 0;
19+
20+
static void Report(const char* name, bool ok)
21+
{
22+
g_total++;
23+
if (ok) g_passed++;
24+
printf(" [%s] %s\n", ok ? "PASS" : "FAIL", name);
25+
}
26+
27+
// =========================================================================
28+
// 1. TLS callback — must fire before entry point
29+
// =========================================================================
30+
static bool g_tlsCallbackFired = false;
31+
32+
static void NTAPI TlsCallback(PVOID, DWORD dwReason, PVOID)
33+
{
34+
if (dwReason == DLL_PROCESS_ATTACH)
35+
g_tlsCallbackFired = true;
36+
}
37+
38+
#ifdef _MSC_VER
39+
#pragma comment(linker, "/INCLUDE:_tls_used")
40+
#pragma section(".CRT$XLB", read)
41+
__declspec(allocate(".CRT$XLB")) PIMAGE_TLS_CALLBACK g_tls_callback = TlsCallback;
42+
#endif
43+
44+
// =========================================================================
45+
// 2. Static TLS
46+
// =========================================================================
47+
static __declspec(thread) int g_tlsValue = 0;
48+
49+
static bool TestStaticTLS()
50+
{
51+
g_tlsValue = 42;
52+
return g_tlsValue == 42;
53+
}
54+
55+
static bool TestTLSPerThread()
56+
{
57+
g_tlsValue = 100;
58+
std::atomic<int> other_value{-1};
59+
60+
std::thread t([&]
61+
{
62+
other_value.store(g_tlsValue);
63+
g_tlsValue = 200;
64+
});
65+
t.join();
66+
67+
return other_value.load() == 0 && g_tlsValue == 100;
68+
}
69+
70+
// =========================================================================
71+
// 3. SEH tests
72+
// =========================================================================
73+
static bool TestSEH()
74+
{
75+
__try
76+
{
77+
*reinterpret_cast<volatile int*>(nullptr) = 0;
78+
}
79+
__except (GetExceptionCode() == EXCEPTION_ACCESS_VIOLATION
80+
? EXCEPTION_EXECUTE_HANDLER
81+
: EXCEPTION_CONTINUE_SEARCH)
82+
{
83+
return true;
84+
}
85+
return false;
86+
}
87+
88+
#pragma optimize("", off)
89+
static bool TestSEHDivZero()
90+
{
91+
__try
92+
{
93+
volatile int a = 1, b = 0;
94+
volatile int c = a / b;
95+
(void)c;
96+
}
97+
__except (GetExceptionCode() == EXCEPTION_INT_DIVIDE_BY_ZERO
98+
? EXCEPTION_EXECUTE_HANDLER
99+
: EXCEPTION_CONTINUE_SEARCH)
100+
{
101+
return true;
102+
}
103+
return false;
104+
}
105+
#pragma optimize("", on)
106+
107+
// =========================================================================
108+
// 4. C++ exception tests
109+
// =========================================================================
110+
static bool TestCppExceptionInt()
111+
{
112+
try
113+
{
114+
throw 42;
115+
}
116+
catch (int v)
117+
{
118+
return v == 42;
119+
}
120+
return false;
121+
}
122+
123+
static bool TestCppExceptionStd()
124+
{
125+
try
126+
{
127+
throw std::runtime_error("test");
128+
}
129+
catch (const std::exception& e)
130+
{
131+
return std::string(e.what()) == "test";
132+
}
133+
return false;
134+
}
135+
136+
struct StackGuard
137+
{
138+
bool& flag;
139+
~StackGuard() { flag = true; }
140+
};
141+
142+
static bool TestCppExceptionUnwind()
143+
{
144+
bool unwound = false;
145+
try
146+
{
147+
StackGuard guard{unwound};
148+
throw 1;
149+
}
150+
catch (...)
151+
{
152+
}
153+
return unwound;
154+
}
155+
156+
// =========================================================================
157+
// 5. Win32 API / imports
158+
// =========================================================================
159+
static bool TestImports()
160+
{
161+
const DWORD pid = GetCurrentProcessId();
162+
const DWORD tid = GetCurrentThreadId();
163+
SYSTEM_INFO si{};
164+
GetSystemInfo(&si);
165+
return pid != 0 && tid != 0 && si.dwPageSize != 0;
166+
}
167+
168+
// =========================================================================
169+
// 6. STL
170+
// =========================================================================
171+
static bool TestSTL()
172+
{
173+
std::vector<int> v = {5, 3, 1, 4, 2};
174+
std::sort(v.begin(), v.end());
175+
std::string s = "manual_map";
176+
return v == std::vector<int>{1, 2, 3, 4, 5} && s.size() == 10;
177+
}
178+
179+
// =========================================================================
180+
// 7. Floating point
181+
// =========================================================================
182+
static bool TestFloatingPoint()
183+
{
184+
const double val = std::sqrt(2.0);
185+
return std::abs(val - 1.41421356) < 0.0001;
186+
}
187+
188+
// =========================================================================
189+
// 8. VirtualAlloc
190+
// =========================================================================
191+
static bool TestVirtualMemory()
192+
{
193+
void* p = VirtualAlloc(nullptr, 4096, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
194+
if (!p) return false;
195+
*static_cast<int*>(p) = 0xDEAD;
196+
const bool ok = *static_cast<int*>(p) == 0xDEAD;
197+
VirtualFree(p, 0, MEM_RELEASE);
198+
return ok;
199+
}
200+
201+
// =========================================================================
202+
// 9. Threading
203+
// =========================================================================
204+
static bool TestThreading()
205+
{
206+
std::mutex mtx;
207+
int counter = 0;
208+
209+
auto worker = [&]
210+
{
211+
for (int i = 0; i < 1000; i++)
212+
{
213+
std::lock_guard lock(mtx);
214+
counter++;
215+
}
216+
};
217+
218+
std::thread t1(worker), t2(worker);
219+
t1.join();
220+
t2.join();
221+
return counter == 2000;
222+
}
223+
224+
// =========================================================================
225+
// 10. Global constructors
226+
// =========================================================================
227+
static std::string g_globalStr = "hello_from_exe";
228+
229+
static bool TestGlobalCtors()
230+
{
231+
return g_globalStr == "hello_from_exe";
232+
}
233+
234+
// =========================================================================
235+
// 11. Vtable dispatch
236+
// =========================================================================
237+
struct Base
238+
{
239+
virtual int value() = 0;
240+
virtual ~Base() = default;
241+
};
242+
243+
struct Derived final : Base
244+
{
245+
int value() override { return 99; }
246+
};
247+
248+
static bool TestVTable()
249+
{
250+
std::unique_ptr<Base> p = std::make_unique<Derived>();
251+
return p->value() == 99;
252+
}
253+
254+
// =========================================================================
255+
// 12. VEH
256+
// =========================================================================
257+
static bool g_vehCaught = false;
258+
259+
static LONG CALLBACK VehHandler(PEXCEPTION_POINTERS info)
260+
{
261+
if (info->ExceptionRecord->ExceptionCode == EXCEPTION_ACCESS_VIOLATION)
262+
{
263+
g_vehCaught = true;
264+
// Skip the faulting instruction (mov [rax], ecx = 2 bytes for the test below)
265+
info->ContextRecord->Rip += 2;
266+
return EXCEPTION_CONTINUE_EXECUTION;
267+
}
268+
return EXCEPTION_CONTINUE_SEARCH;
269+
}
270+
271+
static bool TestVEH()
272+
{
273+
g_vehCaught = false;
274+
void* h = AddVectoredExceptionHandler(1, VehHandler);
275+
if (!h) return false;
276+
277+
volatile int* bad = nullptr;
278+
*bad = 0;
279+
280+
RemoveVectoredExceptionHandler(h);
281+
return g_vehCaught;
282+
}
283+
284+
// =========================================================================
285+
// Entry point
286+
// =========================================================================
287+
int main()
288+
{
289+
printf("========================================\n");
290+
printf("[test_exe] Entry point reached\n");
291+
printf(" image base: %p\n", GetModuleHandleA(nullptr));
292+
printf("========================================\n\n");
293+
294+
Report("TLS callback fired", g_tlsCallbackFired);
295+
Report("Static TLS read/write", TestStaticTLS());
296+
Report("Static TLS per-thread", TestTLSPerThread());
297+
Report("SEH access violation", TestSEH());
298+
Report("SEH divide by zero", TestSEHDivZero());
299+
Report("C++ exception int", TestCppExceptionInt());
300+
Report("C++ exception std::exception", TestCppExceptionStd());
301+
Report("C++ exception stack unwind", TestCppExceptionUnwind());
302+
Report("Win32 API imports", TestImports());
303+
Report("STL containers/strings", TestSTL());
304+
Report("Floating point / math", TestFloatingPoint());
305+
Report("VirtualAlloc/Free", TestVirtualMemory());
306+
Report("Threading + mutex", TestThreading());
307+
Report("Global constructors", TestGlobalCtors());
308+
Report("Vtable dispatch", TestVTable());
309+
Report("Vectored exception handler", TestVEH());
310+
311+
printf("\n========================================\n");
312+
printf("[test_exe] Results: %d/%d passed\n", g_passed, g_total);
313+
printf("========================================\n");
314+
315+
MessageBoxA(nullptr, "All tests have passed!", "Yey", MB_OK);
316+
return 0;
317+
}

include/yail/yail.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@ namespace yail
1212
{
1313
[[nodiscard]]
1414
std::expected<std::uintptr_t, std::string> manual_map_injection_from_raw(
15-
const std::span<std::uint8_t>& raw_dll, std::uintptr_t process_id);
15+
const std::span<std::uint8_t>& raw_image, std::uintptr_t process_id);
1616

1717
[[nodiscard]]
1818
std::expected<std::uintptr_t, std::string> manual_map_injection_from_raw(
19-
const std::span<std::uint8_t>& raw_dll, const std::string_view& process_name);
19+
const std::span<std::uint8_t>& raw_image, const std::string_view& process_name);
2020

2121
[[nodiscard]]
2222
std::expected<std::uintptr_t, std::string> manual_map_injection_from_file(
23-
const std::string_view& dll_path, std::uintptr_t process_id);
23+
const std::string_view& image_path, std::uintptr_t process_id);
2424

2525
[[nodiscard]]
2626
std::expected<std::uintptr_t, std::string> manual_map_injection_from_file(
27-
const std::string_view& dll_path, const std::string_view& process_name);
27+
const std::string_view& image_path, const std::string_view& process_name);
2828
}

0 commit comments

Comments
 (0)