Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 126 additions & 74 deletions Apps/MispConnectorApp/App.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ You should have received a copy of the GNU General Public License
using System.ComponentModel.DataAnnotations;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Security;
Expand Down Expand Up @@ -53,6 +52,7 @@ public sealed class App : IDnsApplication, IDnsRequestBlockingHandler
HttpClient _httpClient;

Uri _mispApiUrl;
Uri _mispServerUrl;

DnsSOARecordData _soaRecord;
TimeSpan _updateInterval;
Expand All @@ -70,7 +70,13 @@ public void Dispose()
{
if (_updateLoopTask != null)
{
_ = Task.WhenAny(_updateLoopTask, Task.Delay(TimeSpan.FromSeconds(2))).GetAwaiter().GetResult();
try
{
_updateLoopTask?.WaitAsync(TimeSpan.FromSeconds(2))
.GetAwaiter()
.GetResult();
}
catch { }
}
}
catch
Expand Down Expand Up @@ -105,9 +111,9 @@ public async Task InitializeAsync(IDnsServer dnsServer, string config)

_updateInterval = ParseUpdateInterval(_config.UpdateInterval);

Uri mispServerUrl = new Uri(_config.MispServerUrl);
_mispApiUrl = new Uri(mispServerUrl, "/attributes/restSearch");
_httpClient = CreateHttpClient(mispServerUrl, _config.DisableTlsValidation);
_mispServerUrl = new Uri(_config.MispServerUrl);
_mispApiUrl = new Uri(_mispServerUrl, "/attributes/restSearch");
_httpClient = CreateHttpClient(_mispServerUrl, _config.DisableTlsValidation);

await LoadBlocklistFromCacheAsync();
_appShutdownCts = new CancellationTokenSource();
Expand All @@ -130,6 +136,8 @@ public async Task InitializeAsync(IDnsServer dnsServer, string config)
}
}

// No allowlist override in this app.
// ProcessRequestAsync handles blocking.
public Task<bool> IsAllowedAsync(DnsDatagram request, IPEndPoint remoteEP)
{
return Task.FromResult(false);
Expand All @@ -151,50 +159,47 @@ public Task<DnsDatagram> ProcessRequestAsync(DnsDatagram request, IPEndPoint rem

string blockingReport = $"source=misp-connector;domain={blockedDomain}";

// Add blocking report as EDE to EDNS options for both TXT and other queries if the query datagram has EDNS field
EDnsOption[] options = null;
if (_config.AddExtendedDnsError && request.EDNS is not null)
{
options = new EDnsOption[] { new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Blocked, string.Empty)) };
options = new EDnsOption[] { new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Blocked, blockingReport)) };
}

DnsResourceRecord[] answer = null;
DnsResourceRecord[] authority = null;
bool authoritative = false;
DnsResponseCode rCode;
if (_config.AllowTxtBlockingReport && question.Type == DnsResourceRecordType.TXT)
{
DnsResourceRecord[] answer = new DnsResourceRecord[] { new DnsResourceRecord(question.Name, DnsResourceRecordType.TXT, question.Class, 60, new DnsTXTRecordData(string.Empty)) };
return Task.FromResult(new DnsDatagram(
ID: request.Identifier,
isResponse: true,
OPCODE: DnsOpcode.StandardQuery,
authoritativeAnswer: false,
truncation: false,
recursionDesired: request.RecursionDesired,
recursionAvailable: true,
authenticData: false,
checkingDisabled: false,
RCODE: DnsResponseCode.NoError,
question: request.Question,
answer: answer,
authority: null,
additional: null,
udpPayloadSize: request.EDNS is null ? ushort.MinValue : _dnsServer.UdpPayloadSize,
ednsFlags: EDnsHeaderFlags.None,
options: options
));
answer = new DnsResourceRecord[] { new DnsResourceRecord(question.Name, DnsResourceRecordType.TXT, question.Class, _config.BlockingAnswerTtl, new DnsTXTRecordData(blockingReport)) };
rCode = DnsResponseCode.NoError;
}
else
{
authority = new DnsResourceRecord[] { new DnsResourceRecord(question.Name, DnsResourceRecordType.SOA, question.Class, _config.BlockingAnswerTtl, _soaRecord) };
rCode = DnsResponseCode.NxDomain;
authoritative = true;
}

return BlockResponse(request: request, options: options, authority: authority, answer: answer, authoritativeAnswer: authoritative, rCode: rCode);
}

DnsResourceRecord[] authority = { new DnsResourceRecord(question.Name, DnsResourceRecordType.SOA, question.Class, 60, _soaRecord) };
private Task<DnsDatagram> BlockResponse(DnsDatagram request, EDnsOption[] options, DnsResourceRecord[] authority, DnsResourceRecord[] answer, bool authoritativeAnswer, DnsResponseCode rCode)
{
return Task.FromResult(new DnsDatagram(
ID: request.Identifier,
isResponse: true,
OPCODE: DnsOpcode.StandardQuery,
authoritativeAnswer: true,
authoritativeAnswer: authoritativeAnswer,
truncation: false,
recursionDesired: request.RecursionDesired,
recursionAvailable: true,
authenticData: false,
checkingDisabled: false,
RCODE: DnsResponseCode.NxDomain,
RCODE: rCode,
question: request.Question,
answer: null,
answer: answer,
authority: authority,
additional: null,
udpPayloadSize: request.EDNS is null ? ushort.MinValue : _dnsServer.UdpPayloadSize,
Expand All @@ -211,24 +216,26 @@ private async Task StartUpdateLoopAsync(CancellationToken cancellationToken)
await Task.Delay(TimeSpan.FromSeconds(Random.Shared.Next(5, 30)), cancellationToken);
using (PeriodicTimer timer = new PeriodicTimer(_updateInterval))
{
while (!cancellationToken.IsCancellationRequested)
while (true)
{
try
{
await UpdateIocsAsync(cancellationToken);
}
catch (OperationCanceledException)
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
_dnsServer.WriteLog("Update loop is shutting down gracefully.");
break;
}

catch (Exception ex)
{
_dnsServer.WriteLog($"FATAL: The MispConnector update task failed unexpectedly. Error: {ex.Message}");
_dnsServer.WriteLog(ex);
}

await timer.WaitForNextTickAsync(cancellationToken);
if (!await timer.WaitForNextTickAsync(cancellationToken))
break;
}
}
}
Expand Down Expand Up @@ -273,15 +280,19 @@ private async Task<bool> CheckTcpPortAsync(Uri serverUri, CancellationToken canc

try
{
using (CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, new CancellationTokenSource(timeout).Token))
using (TcpClient client = new TcpClient())
{
await client.ConnectAsync(host, port, cts.Token);
}
using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
cts.CancelAfter(timeout);

using var client = new TcpClient();
await client.ConnectAsync(host, port, cts.Token);

_dnsServer.WriteLog($"Pre-flight TCP check successful for {host}:{port}.");
return true;
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
throw;
}
catch (OperationCanceledException)
{
_dnsServer.WriteLog($"ERROR: Pre-flight TCP check failed: Connection to {host}:{port} timed out after {timeout.TotalSeconds} seconds. Check firewall rules or network route.");
Expand Down Expand Up @@ -370,41 +381,39 @@ private async Task<HashSet<string>> FetchIocFromMispAsync(CancellationToken canc

break;
}
catch (Exception ex) when (ex is HttpRequestException || ex is SocketException || ex is OperationCanceledException)
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
// These are likely transient network errors, so we should retry.
_dnsServer.WriteLog($"WARNING: A transient network error occurred on page {page}, attempt {attempt}/{maxRetries}. Error: {ex.Message}");
if (attempt < maxRetries)
{
TimeSpan delay = TimeSpan.FromSeconds(Math.Pow(2, attempt)) + TimeSpan.FromMilliseconds(Random.Shared.Next(0, 1000));
_dnsServer.WriteLog($"Waiting for {delay.TotalSeconds:F1} seconds before retrying...");
await Task.Delay(delay, cancellationToken);
}
else
{
// All retries have failed for this page.
_dnsServer.WriteLog($"ERROR: Failed to fetch page {page} after {maxRetries} attempts. Aborting entire update cycle.");
throw;
}
catch (HttpRequestException ex)
{
if (!await HandleRetry(ex, page, maxRetries, attempt, cancellationToken))
throw;
}
catch (SocketException ex)
{
if (!await HandleRetry(ex, page, maxRetries, attempt, cancellationToken))
throw;
}
}

}

List<MispAttribute> attributes = mispResponse?.Response?.Attribute;
if (attributes == null || attributes.Count == 0)
List<MispAttribute> attributes = (mispResponse?.Response?.Attribute) ??
throw new InvalidDataException("Invalid or unexpected MISP response schema.");

if (attributes.Count == 0)
{
hasMorePages = false;
continue;
}

foreach (MispAttribute attribute in attributes)
{
string ioc = attribute.Value?.Trim().ToLowerInvariant();
if (!string.IsNullOrEmpty(ioc))
string ioc = attribute.Value?.Trim();

if (!string.IsNullOrEmpty(ioc) && DnsClient.IsDomainNameValid(ioc))
{
if (DnsClient.IsDomainNameValid(ioc))
{
iocSet.Add(ioc);
}
iocSet.Add(ioc);
}
}

Expand All @@ -423,28 +432,55 @@ private async Task<HashSet<string>> FetchIocFromMispAsync(CancellationToken canc
return iocSet;
}

private async Task<bool> HandleRetry(
Exception ex,
int page,
int maxRetries,
int attempt,
CancellationToken cancellationToken)
{
_dnsServer.WriteLog(
$"WARNING: A transient network error occurred on page {page}, " +
$"attempt {attempt}/{maxRetries}. Error: {ex.Message}");

if (attempt < maxRetries)
{
TimeSpan delay =
TimeSpan.FromSeconds(Math.Pow(2, attempt)) +
TimeSpan.FromMilliseconds(Random.Shared.Next(0, 1000));

_dnsServer.WriteLog(
$"Waiting for {delay.TotalSeconds:F1} seconds before retrying...");

await Task.Delay(delay, cancellationToken);
return true; // retry
}

_dnsServer.WriteLog(
$"ERROR: Failed to fetch page {page} after {maxRetries} attempts.");

return false; // abort
}


private bool IsDomainBlocked(string domain, out string foundZone)
{
FrozenSet<string> currentBlocklist = _domainBlocklist;

// Span-based lookup
FrozenSet<string>.AlternateLookup<ReadOnlySpan<char>> lookup = currentBlocklist.GetAlternateLookup<ReadOnlySpan<char>>();

ReadOnlySpan<char> currentSpan = domain.AsSpan();

while (true)
{
// To look up in a HashSet<string>, we must provide a string.
string key = new string(currentSpan);
if (currentBlocklist.TryGetValue(key, out foundZone))
{
if (lookup.TryGetValue(currentSpan, out foundZone))
return true;
}

int dotIndex = currentSpan.IndexOf('.');
if (dotIndex == -1)
{
break; // No more parent domains.
}
if (dotIndex < 0)
break;

// Slice to the parent domain view. No allocation here.
currentSpan = currentSpan.Slice(dotIndex + 1);
}

Expand All @@ -458,7 +494,17 @@ private async Task LoadBlocklistFromCacheAsync()
{
try
{
FrozenSet<string> domains = (await File.ReadAllLinesAsync(_domainCacheFilePath)).ToHashSet(StringComparer.OrdinalIgnoreCase).ToFrozenSet(StringComparer.OrdinalIgnoreCase);
var set = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

await foreach (var line in File.ReadLinesAsync(_domainCacheFilePath))
{
var d = line.Trim();

if (d.Length > 0)
set.Add(d);
}

FrozenSet<string> domains = set.ToFrozenSet(StringComparer.OrdinalIgnoreCase);
Interlocked.Exchange(ref _domainBlocklist, domains);
_dnsServer.WriteLog($"MISP Connector: Loaded {domains.Count} domains from cache.");
}
Expand All @@ -471,7 +517,7 @@ private async Task LoadBlocklistFromCacheAsync()

private async Task UpdateIocsAsync(CancellationToken cancellationToken)
{
if (!await CheckTcpPortAsync(new Uri(_config.MispServerUrl), cancellationToken))
if (!await CheckTcpPortAsync(_mispServerUrl, cancellationToken))
{
return;
}
Expand All @@ -486,11 +532,11 @@ private async Task UpdateIocsAsync(CancellationToken cancellationToken)
{
await WriteIocsToCacheAsync(domains, cancellationToken);
Interlocked.Exchange(ref _domainBlocklist, domains);
_dnsServer.WriteLog($"MISP Connector: Successfully updated blocklist with {domains.Count} domains.");
_dnsServer.WriteLog($"MISP Connector: Successfully updated currentBlocklist with {domains.Count} domains.");
}
else
{
_dnsServer.WriteLog("MISP data has not changed. No update to blocklist or cache is necessary.");
_dnsServer.WriteLog("MISP data has not changed. No update to currentBlocklist or cache is necessary.");
}
}

Expand Down Expand Up @@ -528,11 +574,16 @@ private class Config

[JsonPropertyName("enableBlocking")]
public bool EnableBlocking { get; set; } = true;

[JsonPropertyName("maxIocAge")]
[Required(ErrorMessage = "maxIocAge is a required configuration property.")]
[RegularExpression(@"^\d+[mhd]$", ErrorMessage = "Invalid interval format. Use a number followed by 'm', 'h', or 'd' (e.g., '90m', '2h', '7d').", MatchTimeoutInMilliseconds = 3000)]
public string MaxIocAge { get; set; }

[JsonPropertyName("blockingAnswerTtl")]
[Range(30, 86400, ErrorMessage = "blockingAnswerTtl must be between 30 and 86400 seconds.")]
public uint BlockingAnswerTtl { get; set; } = 30;

[JsonPropertyName("mispApiKey")]
[Required(ErrorMessage = "mispApiKey is a required configuration property.")]
[MinLength(1, ErrorMessage = "mispApiKey cannot be empty.")]
Expand All @@ -542,6 +593,7 @@ private class Config
[Required(ErrorMessage = "mispServerUrl is a required configuration property.")]
[Url(ErrorMessage = "mispServerUrl must be a valid URL.")]
public string MispServerUrl { get; set; }

[JsonPropertyName("paginationLimit")]
public int PaginationLimit { get; set; } = 5000;

Expand Down
2 changes: 1 addition & 1 deletion Apps/MispConnectorApp/MispConnectorApp.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<PropertyGroup>
<TargetFramework>net9.0</TargetFramework>
<AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath>
<Version>1.0</Version>
<Version>1.1</Version>
<IncludeSourceRevisionInInformationalVersion>false</IncludeSourceRevisionInInformationalVersion>
<Company>Technitium</Company>
<Product>Technitium DNS Server</Product>
Expand Down
Loading