Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# MasterDnsVPN Client
# MasterDnsVPN Client
# Author: MasterkinG32
# Github: https://github.com/masterking32
# Year: 2026
Expand All @@ -17,6 +17,7 @@
import time
from bisect import bisect_left, bisect_right, insort
from collections import defaultdict, deque
from typing import Optional

from dns_utils.ARQ import ARQ
from dns_utils.compression import (
Expand Down Expand Up @@ -63,11 +64,11 @@ def __init__(self) -> None:
# ---------------------------------------------------------
# Runtime and lifecycle primitives
# ---------------------------------------------------------
self.loop: asyncio.AbstractEventLoop | None = None
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.should_stop: asyncio.Event = asyncio.Event()
self.session_restart_event = None
self.rx_tasks = set()
self.cpu_executor: concurrent.futures.ThreadPoolExecutor | None = None
self.cpu_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None

# ---------------------------------------------------------
# Config and logger bootstrap
Expand Down Expand Up @@ -224,7 +225,7 @@ def __init__(self) -> None:
# ---------------------------------------------------------
self.base_encode_responses: bool = self.config.get("BASE_ENCODE_DATA", False)
self.encryption_method: int = self.config.get("DATA_ENCRYPTION_METHOD", 1)
self.encryption_key: str | None = self.config.get("ENCRYPTION_KEY", None)
self.encryption_key: Optional[str] = self.config.get("ENCRYPTION_KEY", None)

if not self.encryption_key:
self.logger.error(
Expand Down Expand Up @@ -503,7 +504,7 @@ def _schedule_recheck_after_failure(
def _format_mtu_log_line(
self,
template: str,
connection: dict | None = None,
connection: Optional[dict] = None,
cause: str = "",
) -> str:
if not template:
Expand Down Expand Up @@ -545,7 +546,7 @@ def _format_mtu_log_line(
def _append_mtu_log_line(
self,
template: str,
connection: dict | None = None,
connection: Optional[dict] = None,
cause: str = "",
output_path: str = "",
) -> None:
Expand Down Expand Up @@ -728,7 +729,7 @@ async def _send_and_receive_dns(
port: int,
timeout: float = 10,
buffer_size: int = 0,
) -> bytes | None:
) -> Optional[bytes]:
"""Send a UDP packet and wait for the response."""
buf_size = buffer_size or self.buffer_size

Expand Down Expand Up @@ -809,7 +810,7 @@ def _get_active_response_queue(self, stream_id: int):
self._deactivate_response_queue(sid)
return None, None

def _match_allowed_domain_suffix(self, qname: str) -> str | None:
def _match_allowed_domain_suffix(self, qname: str) -> Optional[str]:
"""Return the matched allowed domain suffix for qname, if any."""
if not qname:
return None
Expand Down Expand Up @@ -853,9 +854,7 @@ def _apply_session_compression_policy(self) -> None:
f"<cyan>[Compression]</cyan> <green>Effective Compression - Upload: <yellow>{get_compression_name(up)}</yellow>, Download: <yellow>{get_compression_name(down)}</yellow></green>"
)

async def _process_received_packet(
self, response_bytes: bytes, addr=None
) -> tuple[dict | None, bytes]:
async def _process_received_packet(self, response_bytes, addr=None):
"""Parse DNS response, validate source/domain once, then extract VPN payload."""
if not response_bytes:
return None, b""
Expand Down Expand Up @@ -3876,7 +3875,7 @@ def main():
pass
else:
try:
import uvloop # pylint: disable=import-outside-toplevel
import uvloop

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
Expand Down Expand Up @@ -3922,7 +3921,7 @@ def custom_exception_handler(loop, context):
# On Windows, register a Console Ctrl Handler early so Ctrl+C is handled
if sys.platform == "win32":
try:
from ctypes import wintypes # pylint: disable=import-outside-toplevel
from ctypes import wintypes

HandlerRoutine = ctypes.WINFUNCTYPE(wintypes.BOOL, wintypes.DWORD)

Expand Down
10 changes: 5 additions & 5 deletions dns_utils/DnsPacketParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import random
import struct
from typing import Any
from typing import Any, Optional

from .DNS_ENUMS import DNS_QClass, DNS_Record_Type, Packet_Type, DNS_rCode

Expand Down Expand Up @@ -167,7 +167,7 @@ class DnsPacketParser:

def __init__(
self,
logger: Any | None = None,
logger: Optional[Any] = None,
encryption_key: str = "",
encryption_method: int = 1,
):
Expand All @@ -194,7 +194,7 @@ def __init__(

if self.encryption_method in (3, 4, 5):
try:
from cryptography.hazmat.primitives.ciphers.aead import AESGCM # pylint: disable=import-outside-toplevel
from cryptography.hazmat.primitives.ciphers.aead import AESGCM

self._aesgcm = AESGCM(self.key)
except ImportError:
Expand All @@ -203,8 +203,8 @@ def __init__(

elif self.encryption_method == 2:
try:
from cryptography.hazmat.backends import default_backend # pylint: disable=import-outside-toplevel
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms # pylint: disable=import-outside-toplevel
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms

self._Cipher = Cipher
self._default_backend = default_backend
Expand Down
3 changes: 2 additions & 1 deletion dns_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Year: 2026
from loguru import logger
import sys
from typing import Optional
import secrets
import asyncio
import socket
Expand Down Expand Up @@ -117,7 +118,7 @@ def cb():
raise


def load_text(file_path: str) -> str | None:
def load_text(file_path: str) -> Optional[str]:
"""
Load and return the contents of a text file, stripped of leading/trailing whitespace.
Returns None if the file does not exist or error occurs.
Expand Down
39 changes: 20 additions & 19 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import time
from bisect import bisect_left, bisect_right, insort
from collections import deque
from typing import Optional

from dns_utils.ARQ import ARQ
from dns_utils.compression import (
Expand Down Expand Up @@ -68,8 +69,8 @@ def __init__(self) -> None:
# ---------------------------------------------------------
# Runtime primitives
# ---------------------------------------------------------
self.udp_sock: socket.socket | None = None
self.loop: asyncio.AbstractEventLoop | None = None
self.udp_sock: Optional[socket.socket] = None
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.should_stop = asyncio.Event()

# ---------------------------------------------------------
Expand Down Expand Up @@ -184,7 +185,7 @@ def __init__(self) -> None:
self._dns_task = None
self._session_cleanup_task = None
self._background_tasks = set()
self.cpu_executor: concurrent.futures.ThreadPoolExecutor | None = None
self.cpu_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
auto_cpu_workers = max(2, min(16, (os.cpu_count() or 1)))
raw_cpu_workers = int(self.config.get("CPU_WORKER_THREADS", 0))
if raw_cpu_workers < 0:
Expand Down Expand Up @@ -333,7 +334,7 @@ def __init__(self) -> None:
f"Please update your config file to the latest version ({self.min_config_version}) for best performance and new features."
)

def _parse_compression_value(self, value) -> int | None:
def _parse_compression_value(self, value) -> Optional[int]:
if isinstance(value, str):
v = value.strip()
if not v:
Expand Down Expand Up @@ -421,7 +422,7 @@ async def new_session(
client_token: bytes = b"",
client_upload_compression_type: int = 0,
client_download_compression_type: int = 0,
) -> int | None:
) -> Optional[int]:
try:
if not self.free_session_ids:
self.logger.error(
Expand Down Expand Up @@ -572,7 +573,7 @@ def _deactivate_response_queue(self, session: dict, stream_id: int) -> None:
active_ids.pop(idx)

def _extract_packet_payload(
self, labels: str, extracted_header: dict | None
self, labels: str, extracted_header: Optional[dict]
) -> bytes:
"""Extract packet payload and apply optional decompression based on header flag."""
try:
Expand Down Expand Up @@ -638,7 +639,7 @@ async def _handle_session_init(
parsed_packet=None,
session_id=None,
extracted_header=None,
) -> bytes | None:
) -> Optional[bytes]:
"""Handle NEW_SESSION VPN packet."""
try:
client_payload = self._extract_packet_payload(labels, extracted_header)
Expand Down Expand Up @@ -1586,8 +1587,8 @@ async def _handle_pre_session_packet(
data: bytes,
labels: str,
request_domain: str,
extracted_header: dict | None = None,
) -> bytes | None:
extracted_header: Optional[dict] = None,
) -> Optional[bytes]:
if packet_type == Packet_Type.SESSION_INIT:
return await self._handle_session_init(
request_domain=request_domain,
Expand Down Expand Up @@ -1619,7 +1620,7 @@ async def _process_session_packet(
stream_id: int,
sn: int,
labels: str,
extracted_header: dict | None,
extracted_header: Optional[dict],
now_mono: float,
) -> None:
"""Process a session packet without blocking response generation."""
Expand Down Expand Up @@ -1817,7 +1818,7 @@ def _build_invalid_session_error_response(
session_id: int,
request_domain: str,
question_packet: bytes,
closed_info: dict | None,
closed_info: Optional[dict],
) -> bytes:
try:
is_base = (
Expand Down Expand Up @@ -1847,11 +1848,11 @@ async def handle_vpn_packet(
session_id: int,
data: bytes = b"",
labels: str = "",
parsed_packet: dict | None = None,
parsed_packet: Optional[dict] = None,
addr=None,
request_domain: str = "",
extracted_header: dict | None = None,
) -> bytes | None:
extracted_header: Optional[dict] = None,
) -> Optional[bytes]:
# First handle packets that don't require an active session (e.g. session init, MTU negotiation).
if packet_type in self._pre_session_packet_types:
pre_session_response = await self._handle_pre_session_packet(
Expand Down Expand Up @@ -2358,7 +2359,7 @@ async def _handle_set_mtu(
parsed_packet=None,
session_id=None,
extracted_header=None,
) -> bytes | None:
) -> Optional[bytes]:
"""Handle SET_MTU_REQ VPN packet and save it to the session."""
try:
session = self.sessions.get(session_id)
Expand Down Expand Up @@ -2423,7 +2424,7 @@ async def _handle_mtu_down(
parsed_packet=None,
session_id=None,
extracted_header=None,
) -> bytes | None:
) -> Optional[bytes]:
"""Handle MTU_DOWN_REQ (download MTU test) VPN packet."""
try:
download_size_bytes = self._extract_packet_payload(labels, extracted_header)
Expand Down Expand Up @@ -2472,7 +2473,7 @@ async def _handle_mtu_up(
parsed_packet=None,
session_id=None,
extracted_header=None,
) -> bytes | None:
) -> Optional[bytes]:
"""Handle SERVER_UPLOAD_TEST VPN packet."""
try:
raw_label = labels.split(".")[0] if "." in labels else labels
Expand Down Expand Up @@ -3148,7 +3149,7 @@ def main():
pass
else:
try:
import uvloop # pylint: disable=import-outside-toplevel
import uvloop

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
Expand Down Expand Up @@ -3193,7 +3194,7 @@ def custom_exception_handler(loop, context):

if sys.platform == "win32":
try:
from ctypes import wintypes # pylint: disable=import-outside-toplevel
from ctypes import wintypes

HandlerRoutine = ctypes.WINFUNCTYPE(wintypes.BOOL, wintypes.DWORD)

Expand Down