Skip to content

Commit 4495016

Browse files
authored
Add files via upload
1 parent 00bee51 commit 4495016

1 file changed

Lines changed: 377 additions & 0 deletions

File tree

dnsd.py

Lines changed: 377 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,377 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
from __future__ import print_function
4+
5+
import socket
6+
import struct
7+
import threading
8+
import time
9+
import random
10+
import ssl
11+
12+
try:
13+
# Py3
14+
from urllib.request import Request, urlopen
15+
from urllib.parse import urlencode
16+
except ImportError:
17+
# Py2
18+
from urllib2 import Request, urlopen
19+
from urllib import urlencode
20+
21+
# -----------------------------
22+
# Root server IPs (a subset)
23+
# -----------------------------
24+
ROOT_SERVERS = [
25+
"198.41.0.4", # a.root-servers.net
26+
"199.9.14.201", # b.root-servers.net
27+
"192.33.4.12", # c.root-servers.net
28+
"199.7.91.13", # d.root-servers.net
29+
"192.203.230.10", # e.root-servers.net
30+
]
31+
32+
QTYPE = {"A": 1, "NS": 2, "CNAME": 5, "MX": 15, "TXT": 16, "AAAA": 28}
33+
QCLASS_IN = 1
34+
35+
def _bchr(x):
36+
# py2/3
37+
return bytes(bytearray([x]))
38+
39+
def _byte_at(b, i):
40+
v = b[i]
41+
return v if isinstance(v, int) else ord(v)
42+
43+
def encode_qname(name):
44+
name = name.strip().rstrip('.')
45+
if not name:
46+
return b"\x00"
47+
out = []
48+
for part in name.split('.'):
49+
pb = part.encode("utf-8") if not isinstance(part, bytes) else part
50+
out.append(struct.pack("!B", len(pb)) + pb)
51+
out.append(b"\x00")
52+
return b"".join(out)
53+
54+
def decode_name(msg, off):
55+
labels = []
56+
jumped = False
57+
orig = off
58+
seen = set()
59+
while True:
60+
if off >= len(msg):
61+
raise ValueError("name out of bounds")
62+
if off in seen:
63+
raise ValueError("compression loop")
64+
seen.add(off)
65+
66+
ln = _byte_at(msg, off)
67+
if ln == 0:
68+
off += 1
69+
break
70+
if (ln & 0xC0) == 0xC0:
71+
b2 = _byte_at(msg, off + 1)
72+
ptr = ((ln & 0x3F) << 8) | b2
73+
if not jumped:
74+
orig = off + 2
75+
jumped = True
76+
off = ptr
77+
continue
78+
79+
off += 1
80+
lab = msg[off:off+ln]
81+
try:
82+
labels.append(lab.decode("utf-8"))
83+
except Exception:
84+
# best-effort
85+
labels.append("".join(chr(_byte_at(lab, i)) for i in range(len(lab))))
86+
off += ln
87+
88+
return ".".join(labels), (orig if jumped else off)
89+
90+
def build_query(qname, qtype, rd=False, dnssec_do=False):
91+
tid = random.randint(0, 0xFFFF)
92+
flags = 0x0000
93+
if rd:
94+
flags |= 0x0100 # RD
95+
# We are not implementing EDNS0 fully here; dnssec_do placeholder is kept.
96+
header = struct.pack("!HHHHHH", tid, flags, 1, 0, 0, 0)
97+
q = encode_qname(qname) + struct.pack("!HH", qtype, QCLASS_IN)
98+
return tid, header + q
99+
100+
def parse_header(msg):
101+
if len(msg) < 12:
102+
raise ValueError("short header")
103+
tid, flags, qd, an, ns, ar = struct.unpack("!HHHHHH", msg[:12])
104+
tc = bool(flags & 0x0200)
105+
rcode = flags & 0x000F
106+
return tid, flags, qd, an, ns, ar, tc, rcode
107+
108+
def skip_questions(msg, off, qd):
109+
for _ in range(qd):
110+
_, off = decode_name(msg, off)
111+
off += 4
112+
return off
113+
114+
def parse_rr(msg, off):
115+
name, off = decode_name(msg, off)
116+
rtype, rclass, ttl, rdlen = struct.unpack("!HHIH", msg[off:off+10])
117+
off += 10
118+
rdata_off = off
119+
rdata = msg[off:off+rdlen]
120+
off += rdlen
121+
return {
122+
"name": name, "type": rtype, "class": rclass, "ttl": ttl,
123+
"rdlen": rdlen, "rdata": rdata, "rdata_off": rdata_off
124+
}, off
125+
126+
def parse_sections(msg):
127+
tid, flags, qd, an, ns, ar, tc, rcode = parse_header(msg)
128+
off = 12
129+
off = skip_questions(msg, off, qd)
130+
131+
answers = []
132+
for _ in range(an):
133+
rr, off = parse_rr(msg, off)
134+
answers.append(rr)
135+
136+
authority = []
137+
for _ in range(ns):
138+
rr, off = parse_rr(msg, off)
139+
authority.append(rr)
140+
141+
additional = []
142+
for _ in range(ar):
143+
rr, off = parse_rr(msg, off)
144+
additional.append(rr)
145+
146+
return {
147+
"tid": tid, "flags": flags, "rcode": rcode, "tc": tc,
148+
"answers": answers, "authority": authority, "additional": additional,
149+
"raw": msg
150+
}
151+
152+
def rr_ip_from_additional(rr):
153+
# A or AAAA glue
154+
if rr["type"] == QTYPE["A"] and rr["rdlen"] == 4:
155+
b = bytearray(rr["rdata"])
156+
return "%d.%d.%d.%d" % (b[0], b[1], b[2], b[3])
157+
if rr["type"] == QTYPE["AAAA"] and rr["rdlen"] == 16:
158+
# simple hextets (no compression)
159+
b = bytearray(rr["rdata"])
160+
hextets = []
161+
for i in range(0, 16, 2):
162+
hextets.append("%02x%02x" % (b[i], b[i+1]))
163+
return ":".join(hextets)
164+
return None
165+
166+
def udp_exchange(server_ip, wire_query, timeout=2, port=53):
167+
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
168+
s.settimeout(timeout)
169+
try:
170+
s.sendto(wire_query, (server_ip, port))
171+
data, _ = s.recvfrom(4096)
172+
return data
173+
finally:
174+
s.close()
175+
176+
def iterative_resolve(qname, qtype, timeout=2, max_steps=25):
177+
"""
178+
Returns a full DNS response message (bytes) from the final authoritative.
179+
If it cannot fully resolve, returns the last response received.
180+
"""
181+
# Start at root
182+
next_servers = list(ROOT_SERVERS)
183+
last_msg = None
184+
185+
for _step in range(max_steps):
186+
if not next_servers:
187+
break
188+
189+
server = next_servers[0]
190+
next_servers = next_servers[1:]
191+
192+
tid, wire = build_query(qname, qtype, rd=False)
193+
try:
194+
resp = udp_exchange(server, wire, timeout=timeout)
195+
except Exception:
196+
continue
197+
198+
last_msg = resp
199+
parsed = parse_sections(resp)
200+
201+
# If got answers, done (could be CNAME; this simple version returns as-is)
202+
if parsed["answers"]:
203+
return resp
204+
205+
# Otherwise follow referral using NS in authority + glue in additional
206+
ns_names = []
207+
for rr in parsed["authority"]:
208+
if rr["type"] == QTYPE["NS"]:
209+
# NS rdata is a domain name; decode from msg at rdata_off
210+
nsn, _ = decode_name(resp, rr["rdata_off"])
211+
ns_names.append(nsn)
212+
213+
glue_ips = []
214+
for rr in parsed["additional"]:
215+
ip = rr_ip_from_additional(rr)
216+
if ip:
217+
glue_ips.append(ip)
218+
219+
if glue_ips:
220+
# Use glue
221+
next_servers = glue_ips + next_servers
222+
continue
223+
224+
# No glue: resolve NS name A using recursion of *this* iterative resolver
225+
# (bootstraps by resolving ns hostnames)
226+
if ns_names:
227+
resolved_ns_ips = []
228+
for nsn in ns_names[:3]:
229+
# resolve NS hostname to A via iterative
230+
ns_resp = iterative_resolve(nsn, QTYPE["A"], timeout=timeout, max_steps=max_steps)
231+
ns_parsed = parse_sections(ns_resp)
232+
for a_rr in ns_parsed["answers"]:
233+
if a_rr["type"] == QTYPE["A"] and a_rr["rdlen"] == 4:
234+
b = bytearray(a_rr["rdata"])
235+
resolved_ns_ips.append("%d.%d.%d.%d" % (b[0], b[1], b[2], b[3]))
236+
if resolved_ns_ips:
237+
break
238+
if resolved_ns_ips:
239+
next_servers = resolved_ns_ips + next_servers
240+
continue
241+
242+
# If we got here: cannot progress, return what we have
243+
return resp
244+
245+
return last_msg
246+
247+
248+
# -----------------------------
249+
# Simple cache (positive only)
250+
# -----------------------------
251+
class Cache(object):
252+
def __init__(self):
253+
self._lock = threading.Lock()
254+
self._store = {} # key -> (expires_at, wire_response)
255+
256+
def get(self, key):
257+
now = time.time()
258+
with self._lock:
259+
v = self._store.get(key)
260+
if not v:
261+
return None
262+
exp, msg = v
263+
if exp <= now:
264+
del self._store[key]
265+
return None
266+
return msg
267+
268+
def put(self, key, msg, ttl=30):
269+
exp = time.time() + max(1, int(ttl))
270+
with self._lock:
271+
self._store[key] = (exp, msg)
272+
273+
CACHE = Cache()
274+
275+
276+
# -----------------------------
277+
# Stub resolver server (UDP + TCP)
278+
# -----------------------------
279+
def extract_question(msg):
280+
# returns (qname, qtype, qclass, q_off_end)
281+
tid, flags, qd, an, ns, ar, tc, rcode = parse_header(msg)
282+
off = 12
283+
qname, off = decode_name(msg, off)
284+
qtype, qclass = struct.unpack("!HH", msg[off:off+4])
285+
off += 4
286+
return qname, qtype, qclass
287+
288+
def min_ttl_from_answers(resp_msg):
289+
try:
290+
p = parse_sections(resp_msg)
291+
ttls = [rr["ttl"] for rr in p["answers"]] or [30]
292+
return max(1, min(ttls))
293+
except Exception:
294+
return 30
295+
296+
def handle_query_wire(query_wire):
297+
qname, qtype, qclass = extract_question(query_wire)
298+
key = (qname.lower(), qtype, qclass)
299+
300+
cached = CACHE.get(key)
301+
if cached:
302+
return cached
303+
304+
# Iterative upstream (UDP) by default
305+
resp = iterative_resolve(qname, qtype)
306+
307+
# Cache based on min TTL in answer section
308+
ttl = min_ttl_from_answers(resp)
309+
CACHE.put(key, resp, ttl=ttl)
310+
return resp
311+
312+
def udp_server(bind_ip="127.0.0.1", port=5353):
313+
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
314+
s.bind((bind_ip, port))
315+
print("UDP DNS stub listening on %s:%d" % (bind_ip, port))
316+
while True:
317+
data, addr = s.recvfrom(4096)
318+
try:
319+
resp = handle_query_wire(data)
320+
except Exception:
321+
# SERVFAIL minimal
322+
tid = data[:2] if len(data) >= 2 else b"\x00\x00"
323+
# flags: QR=1, RCODE=2
324+
resp = tid + struct.pack("!H", 0x8002) + b"\x00\x01\x00\x00\x00\x00\x00\x00" + data[12:]
325+
s.sendto(resp, addr)
326+
327+
def tcp_server(bind_ip="127.0.0.1", port=5353):
328+
ss = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
329+
ss.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
330+
ss.bind((bind_ip, port))
331+
ss.listen(50)
332+
print("TCP DNS stub listening on %s:%d" % (bind_ip, port))
333+
while True:
334+
conn, addr = ss.accept()
335+
threading.Thread(target=_tcp_client, args=(conn,), daemon=False).start()
336+
337+
def _recvn(conn, n):
338+
data = b""
339+
while len(data) < n:
340+
chunk = conn.recv(n - len(data))
341+
if not chunk:
342+
return None
343+
data += chunk
344+
return data
345+
346+
def _tcp_client(conn):
347+
try:
348+
hdr = _recvn(conn, 2)
349+
if not hdr:
350+
return
351+
(ln,) = struct.unpack("!H", hdr)
352+
q = _recvn(conn, ln)
353+
if not q:
354+
return
355+
try:
356+
resp = handle_query_wire(q)
357+
except Exception:
358+
resp = make_servfail(q)
359+
conn.sendall(struct.pack("!H", len(resp)) + resp)
360+
finally:
361+
try:
362+
conn.close()
363+
except Exception:
364+
pass
365+
366+
def run_stub(bind_ip="127.0.0.1", port=5353):
367+
t1 = threading.Thread(target=udp_server, args=(bind_ip, port))
368+
t2 = threading.Thread(target=tcp_server, args=(bind_ip, port))
369+
t1.daemon = True
370+
t2.daemon = True
371+
t1.start()
372+
t2.start()
373+
while True:
374+
time.sleep(3600)
375+
376+
if __name__ == "__main__":
377+
run_stub("127.0.0.1", 5353)

0 commit comments

Comments
 (0)