Skip to content

Commit 0c2d2f3

Browse files
author
Vlada Kanivets
committed
add tests for FltCommunicationPort
1 parent 83448bc commit 0c2d2f3

7 files changed

Lines changed: 514 additions & 21 deletions

File tree

include/kf/FltCommunicationPort.h

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "ObjectAttributes.h"
33
#include "ScopeExit.h"
44
#include "VariableSizeStruct.h"
5+
#include "IWinApi.h"
56

67
namespace kf
78
{
@@ -69,7 +70,7 @@ namespace kf
6970
class FltCommunicationPort
7071
{
7172
public:
72-
FltCommunicationPort() : m_filter(), m_port()
73+
FltCommunicationPort(IWinApi& api) : m_api(api), m_filter(), m_port()
7374
{
7475
}
7576

@@ -85,26 +86,26 @@ namespace kf
8586
m_filter = filter;
8687

8788
PSECURITY_DESCRIPTOR securityDescriptor = nullptr;
88-
NTSTATUS status = ::FltBuildDefaultSecurityDescriptor(&securityDescriptor, FLT_PORT_ALL_ACCESS);
89+
NTSTATUS status = m_api.FltBuildDefaultSecurityDescriptor(&securityDescriptor, FLT_PORT_ALL_ACCESS);
8990

9091
if (!NT_SUCCESS(status))
9192
{
9293
return status;
9394
}
9495

95-
SCOPE_EXIT{ ::FltFreeSecurityDescriptor(securityDescriptor); };
96+
SCOPE_EXIT{ m_api.FltFreeSecurityDescriptor(securityDescriptor); };
9697

9798
VariableSizeStruct<SYSTEM_MANDATORY_LABEL_ACE, PagedPool> lowIntegrityAce;
9899
VariableSizeStruct<ACL, PagedPool> sacl;
99100
if (allowNonAdmins)
100101
{
101-
status = RtlSetDaclSecurityDescriptor(securityDescriptor, true, nullptr, false);
102+
status = m_api.RtlSetDaclSecurityDescriptor(securityDescriptor, true, nullptr, false);
102103
if (!NT_SUCCESS(status))
103104
{
104105
return status;
105106
}
106107

107-
const auto lowMandatorySidLength = RtlLengthSid(SeExports->SeLowMandatorySid);
108+
const auto lowMandatorySidLength = m_api.RtlLengthSid(SeExports->SeLowMandatorySid);
108109
status = lowIntegrityAce.emplace(FIELD_OFFSET(SYSTEM_MANDATORY_LABEL_ACE, SidStart) + lowMandatorySidLength);
109110
if (!NT_SUCCESS(status))
110111
{
@@ -114,7 +115,7 @@ namespace kf
114115
lowIntegrityAce->Header.AceType = SYSTEM_MANDATORY_LABEL_ACE_TYPE;
115116
lowIntegrityAce->Header.AceSize = static_cast<USHORT>(FIELD_OFFSET(SYSTEM_MANDATORY_LABEL_ACE, SidStart) + lowMandatorySidLength);
116117
lowIntegrityAce->Mask = 0;
117-
status = RtlCopySid(lowMandatorySidLength, &lowIntegrityAce->SidStart, SeExports->SeLowMandatorySid);
118+
status = m_api.RtlCopySid(lowMandatorySidLength, &lowIntegrityAce->SidStart, SeExports->SeLowMandatorySid);
118119
if (!NT_SUCCESS(status))
119120
{
120121
return status;
@@ -126,19 +127,19 @@ namespace kf
126127
{
127128
return status;
128129
}
129-
status = RtlCreateAcl(sacl.get(), saclSize, ACL_REVISION);
130+
status = m_api.RtlCreateAcl(sacl.get(), saclSize, ACL_REVISION);
130131
if (!NT_SUCCESS(status))
131132
{
132133
return status;
133134
}
134135

135-
status = RtlAddAce(sacl.get(), ACL_REVISION, 0, static_cast<PVOID>(lowIntegrityAce.get()), lowIntegrityAce->Header.AceSize);
136+
status = m_api.RtlAddAce(sacl.get(), ACL_REVISION, 0, static_cast<PVOID>(lowIntegrityAce.get()), lowIntegrityAce->Header.AceSize);
136137
if (!NT_SUCCESS(status))
137138
{
138139
return status;
139140
}
140141

141-
status = RtlSetSaclSecurityDescriptor(securityDescriptor, true, sacl.get(), false);
142+
status = m_api.RtlSetSaclSecurityDescriptor(securityDescriptor, true, sacl.get(), false);
142143
if (!NT_SUCCESS(status))
143144
{
144145
return status;
@@ -147,14 +148,14 @@ namespace kf
147148

148149
ObjectAttributes oa(&name, securityDescriptor);
149150

150-
return ::FltCreateCommunicationPort(filter, &m_port, &oa, this, connectNotify, disconnectNotify, messageNotify, maxConnections);
151+
return m_api.FltCreateCommunicationPort(filter, &m_port, &oa, this, connectNotify, disconnectNotify, messageNotify, maxConnections);
151152
}
152153

153154
void close()
154155
{
155156
if (m_port)
156157
{
157-
::FltCloseCommunicationPort(m_port);
158+
m_api.FltCloseCommunicationPort(m_port);
158159
m_port = nullptr;
159160
}
160161

@@ -175,7 +176,7 @@ namespace kf
175176
{
176177
ASSERT(serverPortCookie);
177178
auto self = static_cast<FltCommunicationPort*>(serverPortCookie);
178-
return Handler::onConnect(self->m_filter, clientPort, connectionContext, connectionContextLength, reinterpret_cast<Handler**>(connectionCookie));
179+
return Handler::onConnect(self->m_filter, clientPort, connectionContext, connectionContextLength, reinterpret_cast<Handler**>(connectionCookie), self->m_api);
179180
}
180181

181182
static VOID FLTAPI disconnectNotify(
@@ -214,15 +215,15 @@ namespace kf
214215
{
215216
if (inputBufferLength)
216217
{
217-
inputMdl = IoAllocateMdl(inputBuffer, inputBufferLength, false, false, nullptr);
218+
inputMdl = handler->m_api.IoAllocateMdl(inputBuffer, inputBufferLength, false, false, nullptr);
218219
if (!inputMdl)
219220
{
220221
return STATUS_INSUFFICIENT_RESOURCES;
221222
}
222223

223-
MmProbeAndLockPages(inputMdl, KernelMode, IoReadAccess);
224+
handler->m_api.MmProbeAndLockPages(inputMdl, KernelMode, IoReadAccess);
224225

225-
inputBuffer = MmGetSystemAddressForMdlSafe(inputMdl, NormalPagePriority | MdlMappingNoExecute | MdlMappingNoWrite);
226+
inputBuffer = handler->m_api.MmGetSystemAddressForMdlSafe(inputMdl, NormalPagePriority | MdlMappingNoExecute | MdlMappingNoWrite);
226227
if (!inputBuffer)
227228
{
228229
return STATUS_INSUFFICIENT_RESOURCES;
@@ -231,15 +232,15 @@ namespace kf
231232

232233
if (outputBufferLength)
233234
{
234-
outputMdl = IoAllocateMdl(outputBuffer, outputBufferLength, false, false, nullptr);
235+
outputMdl = handler->m_api.IoAllocateMdl(outputBuffer, outputBufferLength, false, false, nullptr);
235236
if (!outputMdl)
236237
{
237238
return STATUS_INSUFFICIENT_RESOURCES;
238239
}
239240

240-
MmProbeAndLockPages(outputMdl, KernelMode, IoWriteAccess);
241+
handler->m_api.MmProbeAndLockPages(outputMdl, KernelMode, IoWriteAccess);
241242

242-
outputBuffer = MmGetSystemAddressForMdlSafe(outputMdl, NormalPagePriority | MdlMappingNoExecute);
243+
outputBuffer = handler->m_api.MmGetSystemAddressForMdlSafe(outputMdl, NormalPagePriority | MdlMappingNoExecute);
243244
if (!outputBuffer)
244245
{
245246
return STATUS_INSUFFICIENT_RESOURCES;
@@ -258,16 +259,16 @@ namespace kf
258259
// Cleanup
259260
//
260261

261-
auto freeMdl = [](PMDL& mdl)
262+
auto freeMdl = [&handler](PMDL& mdl)
262263
{
263264
if (mdl)
264265
{
265266
if (FlagOn(mdl->MdlFlags, MDL_PAGES_LOCKED))
266267
{
267-
MmUnlockPages(mdl);
268+
handler->m_api.MmUnlockPages(mdl);
268269
}
269270

270-
IoFreeMdl(mdl);
271+
handler->m_api.IoFreeMdl(mdl);
271272
mdl = nullptr;
272273
}
273274
};
@@ -281,5 +282,6 @@ namespace kf
281282
private:
282283
PFLT_FILTER m_filter;
283284
PFLT_PORT m_port;
285+
IWinApi& m_api;
284286
};
285287
} // namespace

include/kf/IWinApi.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#pragma once
2+
3+
namespace kf
4+
{
5+
////////////////////////////////////////////////////
6+
// Interface for Windows API calls
7+
class IWinApi
8+
{
9+
public:
10+
virtual NTSTATUS FltBuildDefaultSecurityDescriptor(PSECURITY_DESCRIPTOR* sd, ACCESS_MASK access) = 0;
11+
12+
virtual VOID FltFreeSecurityDescriptor(PSECURITY_DESCRIPTOR sd) = 0;
13+
14+
virtual NTSTATUS FltCreateCommunicationPort(
15+
PFLT_FILTER filter,
16+
PFLT_PORT* serverPort,
17+
POBJECT_ATTRIBUTES oa,
18+
PVOID serverPortCookie,
19+
PFLT_CONNECT_NOTIFY connectNotify,
20+
PFLT_DISCONNECT_NOTIFY disconnectNotify,
21+
PFLT_MESSAGE_NOTIFY messageNotify,
22+
LONG maxConnections) = 0;
23+
24+
virtual VOID FltCloseCommunicationPort(PFLT_PORT port) = 0;
25+
26+
virtual VOID FltCloseClientPort(PFLT_FILTER filter, PFLT_PORT* port) = 0;
27+
28+
virtual NTSTATUS RtlSetDaclSecurityDescriptor(PSECURITY_DESCRIPTOR sd, BOOLEAN daclPresent, PACL dacl, BOOLEAN daclDefaulted) = 0;
29+
30+
virtual NTSTATUS RtlSetSaclSecurityDescriptor(PSECURITY_DESCRIPTOR sd, BOOLEAN saclPresent, PACL sacl, BOOLEAN saclDefaulted) = 0;
31+
32+
virtual ULONG RtlLengthSid(PSID sid) = 0;
33+
34+
virtual NTSTATUS RtlCopySid(ULONG len, PSID dest, PSID src) = 0;
35+
36+
virtual NTSTATUS RtlCreateAcl(PACL acl, ULONG size, ULONG rev) = 0;
37+
38+
virtual NTSTATUS RtlAddAce(PACL acl, ULONG rev, ULONG start, PVOID ace, ULONG aceSize) = 0;
39+
40+
virtual PMDL IoAllocateMdl(PVOID va, ULONG len, BOOLEAN secondary, BOOLEAN chargeQuota, PIRP irp) = 0;
41+
42+
virtual VOID IoFreeMdl(PMDL mdl) = 0;
43+
44+
virtual VOID MmProbeAndLockPages(PMDL mdl, KPROCESSOR_MODE mode, LOCK_OPERATION op) = 0;
45+
46+
virtual VOID MmUnlockPages(PMDL mdl) = 0;
47+
48+
virtual PVOID MmGetSystemAddressForMdlSafe(PMDL mdl, ULONG priority) = 0;
49+
};
50+
}

include/kf/WinApi.h

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#pragma once
2+
#include "IWinApi.h"
3+
4+
namespace kf
5+
{
6+
//////////////////////////////////////////////////////////////////////////
7+
// Wrapper for Windows API calls to allow mocking in unit tests
8+
class WinApi : public IWinApi
9+
{
10+
public:
11+
NTSTATUS FltBuildDefaultSecurityDescriptor(PSECURITY_DESCRIPTOR* sd, ACCESS_MASK access)
12+
{
13+
return ::FltBuildDefaultSecurityDescriptor(sd, access);
14+
}
15+
16+
VOID FltFreeSecurityDescriptor(PSECURITY_DESCRIPTOR sd)
17+
{
18+
::FltFreeSecurityDescriptor(sd);
19+
}
20+
NTSTATUS FltCreateCommunicationPort(
21+
PFLT_FILTER filter,
22+
PFLT_PORT* serverPort,
23+
POBJECT_ATTRIBUTES oa,
24+
PVOID serverPortCookie,
25+
PFLT_CONNECT_NOTIFY connectNotify,
26+
PFLT_DISCONNECT_NOTIFY disconnectNotify,
27+
PFLT_MESSAGE_NOTIFY messageNotify,
28+
LONG maxConnections)
29+
{
30+
return ::FltCreateCommunicationPort(filter, serverPort, oa, serverPortCookie, connectNotify, disconnectNotify, messageNotify, maxConnections);
31+
}
32+
33+
VOID FltCloseCommunicationPort(PFLT_PORT port)
34+
{
35+
::FltCloseCommunicationPort(port);
36+
}
37+
38+
VOID FltCloseClientPort(PFLT_FILTER filter, PFLT_PORT* port)
39+
{
40+
::FltCloseClientPort(filter, port);
41+
}
42+
43+
NTSTATUS RtlSetDaclSecurityDescriptor(PSECURITY_DESCRIPTOR sd, BOOLEAN daclPresent, PACL dacl, BOOLEAN daclDefaulted)
44+
{
45+
return ::RtlSetDaclSecurityDescriptor(sd, daclPresent, dacl, daclDefaulted);
46+
}
47+
48+
NTSTATUS RtlSetSaclSecurityDescriptor(PSECURITY_DESCRIPTOR sd, BOOLEAN saclPresent, PACL sacl, BOOLEAN saclDefaulted)
49+
{
50+
return ::RtlSetSaclSecurityDescriptor(sd, saclPresent, sacl, saclDefaulted);
51+
}
52+
53+
ULONG RtlLengthSid(PSID sid)
54+
{
55+
return ::RtlLengthSid(sid);
56+
}
57+
58+
NTSTATUS RtlCopySid(ULONG len, PSID dest, PSID src)
59+
{
60+
return ::RtlCopySid(len, dest, src);
61+
}
62+
63+
NTSTATUS RtlCreateAcl(PACL acl, ULONG size, ULONG rev)
64+
{
65+
return ::RtlCreateAcl(acl, size, rev);
66+
}
67+
68+
NTSTATUS RtlAddAce(PACL acl, ULONG rev, ULONG start, PVOID ace, ULONG aceSize)
69+
{
70+
return ::RtlAddAce(acl, rev, start, ace, aceSize);
71+
}
72+
73+
PMDL IoAllocateMdl(PVOID va, ULONG len, BOOLEAN secondary, BOOLEAN chargeQuota, PIRP irp)
74+
{
75+
return ::IoAllocateMdl(va, len, secondary, chargeQuota, irp);
76+
}
77+
78+
VOID IoFreeMdl(PMDL mdl)
79+
{
80+
::IoFreeMdl(mdl);
81+
}
82+
83+
VOID MmProbeAndLockPages(PMDL mdl, KPROCESSOR_MODE mode, LOCK_OPERATION op)
84+
{
85+
::MmProbeAndLockPages(mdl, mode, op);
86+
}
87+
88+
VOID MmUnlockPages(PMDL mdl)
89+
{
90+
::MmUnlockPages(mdl);
91+
}
92+
93+
PVOID MmGetSystemAddressForMdlSafe(PMDL mdl, ULONG priority)
94+
{
95+
return ::MmGetSystemAddressForMdlSafe(mdl, priority);
96+
}
97+
};
98+
}

test/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ wdk_add_driver(kf-test WINVER NTDDI_WIN10 STL
5959
AutoSpinLockTest.cpp
6060
EResourceSharedLockTest.cpp
6161
RecursiveAutoSpinLockTest.cpp
62+
FltCommunicationPortTest.cpp
63+
WinApiMock.h
6264
)
6365

6466
target_link_libraries(kf-test kf::kf kmtest::kmtest)

0 commit comments

Comments
 (0)