1+ using System ;
2+ using System . Collections . Generic ;
3+ using System . Globalization ;
4+ using System . Linq ;
5+ using System . Net ;
6+ using System . Text . Json ;
7+ using System . Threading . Tasks ;
8+ using DnsServerCore . ApplicationCommon ;
9+ using TechnitiumLibrary . Net ;
10+ using TechnitiumLibrary . Net . Dns ;
11+ using TechnitiumLibrary . Net . Dns . ResourceRecords ;
12+
13+ namespace SourceFilterApp ;
14+
15+ public sealed class App : IDnsApplication , IDnsPostProcessor
16+ {
17+ #region IDisposable
18+
19+ public void Dispose ( ) { }
20+
21+ #endregion
22+
23+ #region properties
24+
25+ public string Description => "Filters answer records by client network according to include/exclude rules and optional splitNetworks." ;
26+
27+ #endregion
28+
29+ #region private
30+
31+ private Rule GetRule ( string name )
32+ {
33+ Rule best = null ;
34+ var bestScore = - 1 ;
35+
36+ foreach ( var rule in this . rules )
37+ {
38+ var score = rule . Match ( name ) ;
39+
40+ if ( score <= bestScore )
41+ continue ;
42+ bestScore = score ;
43+ best = rule ;
44+ }
45+
46+ return best ;
47+ }
48+
49+ #endregion
50+
51+ #region variables
52+
53+ private bool enabled ;
54+ private Rule [ ] rules ;
55+
56+ #endregion
57+
58+ #region public
59+
60+ public Task InitializeAsync ( IDnsServer dnsServer , string config )
61+ {
62+ var list = new List < Rule > ( ) ;
63+
64+ if ( string . IsNullOrEmpty ( config ) )
65+ {
66+ this . enabled = false ;
67+
68+ return Task . CompletedTask ;
69+ }
70+
71+ using ( var json = JsonDocument . Parse ( config ) )
72+ {
73+ var root = json . RootElement ;
74+ this . enabled = ! root . TryGetProperty ( "enabled" , out var jsonEnabled ) || jsonEnabled . GetBoolean ( ) ;
75+
76+ if ( root . TryGetProperty ( "rules" , out var jsonRules ) && jsonRules . ValueKind == JsonValueKind . Array )
77+ foreach ( var jsonRule in jsonRules . EnumerateArray ( ) )
78+ list . Add ( new ( jsonRule ) ) ;
79+ else
80+ foreach ( var prop in root . EnumerateObject ( ) . Where ( prop => ! prop . NameEquals ( "enabled" ) ) )
81+ list . Add ( new ( prop . Name , prop . Value ) ) ;
82+ }
83+
84+ this . rules = list . Count == 0 ? [ ] : list . ToArray ( ) ;
85+
86+ return Task . CompletedTask ;
87+ }
88+
89+ public Task < DnsDatagram > PostProcessAsync ( DnsDatagram request , IPEndPoint remoteEP , DnsTransportProtocol protocol , DnsDatagram response )
90+ {
91+ if ( ! this . enabled )
92+ return Task . FromResult ( response ) ;
93+
94+ if ( response . Answer . Count == 0 )
95+ return Task . FromResult ( response ) ;
96+
97+ var clientIp = remoteEP . Address ;
98+ var answer = new List < DnsResourceRecord > ( response . Answer . Count ) ;
99+
100+ foreach ( var record in response . Answer )
101+ {
102+ var rule = this . GetRule ( record . Name ) ;
103+ if ( rule is null )
104+ {
105+ answer . Add ( record ) ;
106+
107+ continue ;
108+ }
109+
110+ if ( ! rule . IsClientAllowed ( clientIp ) )
111+ continue ;
112+
113+ if ( rule . PassesSplit ( clientIp , record ) )
114+ answer . Add ( record ) ;
115+ }
116+
117+ if ( answer . Count == response . Answer . Count )
118+ return Task . FromResult ( response ) ;
119+
120+ if ( answer . Count == 0 )
121+ return Task . FromResult ( response . Clone ( [ ] ) ) ;
122+
123+ return Task . FromResult ( response . Clone ( answer ) ) ;
124+ }
125+
126+ #endregion
127+
128+ #region inner
129+
130+ private sealed class Rule
131+ {
132+ private readonly NetworkSet exclude ;
133+ private readonly NetworkSet include ;
134+ private readonly string pattern ;
135+ private readonly int specificity ;
136+ private readonly NetworkSet split ;
137+ private readonly bool wildcard ;
138+
139+ public Rule ( JsonElement json ) : this (
140+ ( json . TryGetProperty ( "pattern" , out var jsonPattern )
141+ ? jsonPattern . ValueKind == JsonValueKind . String ? jsonPattern . GetString ( ) : jsonPattern . ToString ( )
142+ : null )
143+ ?? "*" ,
144+ json ) { }
145+
146+ public Rule ( string pattern , JsonElement jsonRule )
147+ {
148+ this . pattern = Normalize ( pattern ) ;
149+ this . wildcard = this . pattern == "*" || this . pattern . StartsWith ( "*." ) ;
150+ this . specificity = this . wildcard
151+ ? this . pattern == "*" ? 0 : this . pattern . Length - 2
152+ : this . pattern . Length ;
153+
154+ this . include = new ( GetNetworks ( jsonRule , true , "includeNetworks" , "include" ) ) ;
155+ this . exclude = new ( GetNetworks ( jsonRule , false , "excludeNetworks" , "exclude" ) ) ;
156+ this . split = new ( GetNetworks ( jsonRule , false , "splitNetworks" ) ) ;
157+ }
158+
159+ private static List < NetworkAddress > GetNetworks ( JsonElement json , bool addDefault , params string [ ] names )
160+ {
161+ var list = new List < NetworkAddress > ( ) ;
162+
163+ foreach ( var n in names )
164+ {
165+ if ( ! json . TryGetProperty ( n , out var value ) || value . ValueKind != JsonValueKind . Array )
166+ continue ;
167+
168+ foreach ( var str in value . EnumerateArray ( ) . Select ( x => x . GetString ( ) ) )
169+ if ( NetworkAddress . TryParse ( str , out var addr ) )
170+ list . Add ( addr ) ;
171+ }
172+
173+ if ( addDefault && list . Count == 0 )
174+ {
175+ list . Add ( NetworkAddress . Parse ( "0.0.0.0/0" ) ) ;
176+ list . Add ( NetworkAddress . Parse ( "::/0" ) ) ;
177+ }
178+
179+ return list ;
180+ }
181+
182+ public int Match ( string name )
183+ {
184+ name = Normalize ( name ) ;
185+
186+ if ( this . pattern == "*" )
187+ return 0 ;
188+
189+ if ( this . wildcard )
190+ {
191+ if ( ! name . EndsWith ( this . pattern [ 1 ..] , StringComparison . OrdinalIgnoreCase ) )
192+ return - 1 ;
193+ if ( name . Length == this . specificity )
194+ return - 1 ;
195+
196+ return this . specificity ;
197+ }
198+
199+ return name . Equals ( this . pattern , StringComparison . OrdinalIgnoreCase )
200+ ? this . specificity
201+ : - 1 ;
202+ }
203+
204+ public bool IsClientAllowed ( IPAddress clientIp )
205+ {
206+ if ( ! this . include . Contains ( clientIp ) )
207+ return false ;
208+ if ( this . exclude . Contains ( clientIp ) )
209+ return false ;
210+
211+ return true ;
212+ }
213+
214+ public bool PassesSplit ( IPAddress clientIp , DnsResourceRecord record )
215+ {
216+ if ( this . split . IsEmpty )
217+ return true ;
218+
219+ var recordIp = record switch
220+ {
221+ { Type : DnsResourceRecordType . A , RDATA : DnsARecordData a } => a . Address ,
222+ { Type : DnsResourceRecordType . AAAA , RDATA : DnsAAAARecordData aaaa } => aaaa . Address ,
223+ _ => null
224+ } ;
225+
226+ if ( recordIp is null )
227+ return true ;
228+
229+ var clientInside = this . split . Contains ( clientIp ) ;
230+ var recordInside = this . split . Contains ( recordIp ) ;
231+
232+ return clientInside == recordInside ;
233+ }
234+ }
235+
236+ private sealed class NetworkSet
237+ {
238+ private readonly NetworkAddress [ ] nets ;
239+
240+ public NetworkSet ( IReadOnlyList < NetworkAddress > nets ) => this . nets = nets . Count == 0 ? [ ] : nets . ToArray ( ) ;
241+
242+ public bool IsEmpty => this . nets . Length == 0 ;
243+
244+ public bool Contains ( IPAddress ip )
245+ {
246+ foreach ( var net in this . nets )
247+ if ( net . Contains ( ip ) )
248+ return true ;
249+
250+ return false ;
251+ }
252+ }
253+
254+ private static readonly IdnMapping idn = new ( ) ;
255+
256+ private static string Normalize ( string s )
257+ {
258+ s = s . TrimEnd ( '.' ) ;
259+
260+ if ( s == "*" )
261+ return s . ToLowerInvariant ( ) ;
262+ if ( s . StartsWith ( "*." ) )
263+ return "*." + idn . GetAscii ( s [ 2 ..] ) . ToLowerInvariant ( ) ;
264+
265+ return idn . GetAscii ( s ) . ToLowerInvariant ( ) ;
266+ }
267+
268+ #endregion
269+ }
0 commit comments