Skip to content

Commit 3f28bbf

Browse files
committed
netstack/icmp: from icmp responder use icmp forwarder
1 parent dfa0dca commit 3f28bbf

3 files changed

Lines changed: 57 additions & 17 deletions

File tree

intra/netstack/dispatchers.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,7 @@ func (d *readVDispatcher) io(fds *fds) (bool, tcpip.Error) {
320320
defer pkt.DecRef()
321321

322322
if d.icmp.ok() {
323-
p := pkt.Data()
324-
if d.icmp.handle(pkt.NICID, p.ToBuffer()) {
323+
if d.icmp.respond(pkt) {
325324
return cont, nil
326325
}
327326
}

intra/netstack/icmp.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,10 @@ func OutboundICMP(id string, s *stack.Stack, hdl GICMPHandler) {
4141
return
4242
}
4343

44-
setICMPEchoHandler(hdl)
45-
4644
forwarder := newIcmpForwarder(id, s, hdl)
4745
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, forwarder.reply4)
4846
s.SetTransportProtocolHandler(icmp.ProtocolNumber6, forwarder.reply6)
47+
setICMPEchoHandler(forwarder)
4948
}
5049

5150
func newIcmpForwarder(owner string, s *stack.Stack, h GICMPHandler) *icmpForwarder {

intra/netstack/icmpecho.go

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@ import (
1818
const minICMPPacketSize = header.ICMPv4MinimumSize + header.IPv4MinimumSize
1919
const typicalICMPEchoPayloadSize = 64 // or 56
2020
const expectedICMPPacketSize = header.IPv6MinimumSize + header.ICMPv6MinimumSize + typicalICMPEchoPayloadSize
21+
const useIcmpForwarder = true
2122

2223
// TODO: get rid of the global in favor of passing the handler via the responder.
2324
// hdlEcho stores the ICMP handler used by the dispatcher-level ICMP
2425
// interception path.
25-
var hdlEcho = core.NewZeroVolatile[GICMPHandler]()
26+
var hdlEcho = core.NewZeroVolatile[*icmpForwarder]()
2627

27-
func setICMPEchoHandler(h GICMPHandler) {
28+
func setICMPEchoHandler(h *icmpForwarder) {
2829
hdlEcho.Store(h)
2930
}
3031

@@ -57,20 +58,30 @@ func (r *icmpResponder) ok() bool {
5758
return r != nil && r.open.Load() && r.ep != nil
5859
}
5960

61+
func (r *icmpResponder) respond(pkt *stack.PacketBuffer) (handled bool) {
62+
if !r.ok() {
63+
return
64+
}
65+
defer pkt.DecRef()
66+
67+
return r.handle(pkt.NICID, pkt.IncRef())
68+
}
69+
6070
// handle returns true if the packet is ICMP and is handled (or dropped) by the
6171
// bypass path.
62-
func (r *icmpResponder) handle(nic tcpip.NICID, b buffer.Buffer) (handled bool) {
72+
func (r *icmpResponder) handle(nic tcpip.NICID, pkt *stack.PacketBuffer) (handled bool) {
6373
if !r.ok() {
6474
return
6575
}
6676

67-
inSize := b.Size()
77+
inSize := pkt.Size()
6878
if inSize <= minICMPPacketSize {
6979
// too verbose: log.VV("icmp: responder: packet too small: %d", inSize)
7080
// Too small to be a valid ICMP echo request.
7181
return
7282
}
7383

84+
b := pkt.ToBuffer()
7485
h := hdlEcho.Load()
7586
var w core.ByteWriter
7687
defer w.Close()
@@ -79,13 +90,14 @@ func (r *icmpResponder) handle(nic tcpip.NICID, b buffer.Buffer) (handled bool)
7990
n, err := b.ReadToWriter(&w, expectedICMPPacketSize)
8091

8192
if settings.Debug {
82-
logeif(err)("icmp: responder: read to writer (sz: %d / %d / %d); h? %t, err? %v", n, w.Len(), inSize, h != nil, err)
93+
logeif(err)("icmp: responder: read to writer (sz: %d / %d / %d); h? %t / fwd? %t, err? %v",
94+
n, w.Len(), inSize, h != nil, useIcmpForwarder, err)
8395
}
8496
if err != nil || n == 0 || h == nil || w.Len() == 0 {
8597
return
8698
}
8799

88-
truncated := inSize > int64(w.Len())
100+
truncated := inSize > w.Len()
89101
parsed := wire.Pool.Get()
90102
parsed.DecodeTrunc(w.Copy(), truncated)
91103

@@ -117,7 +129,7 @@ func (r *icmpResponder) handle(nic tcpip.NICID, b buffer.Buffer) (handled bool)
117129
// There is more data beyond the minimum ICMP echo request.
118130
// Reconstruct the full packet.
119131
w.Reset()
120-
b.ReadToWriter(&w, inSize)
132+
b.ReadToWriter(&w, int64(inSize))
121133
parsed.Decode(w.Copy())
122134
}
123135

@@ -130,17 +142,47 @@ func (r *icmpResponder) handle(nic tcpip.NICID, b buffer.Buffer) (handled bool)
130142
return
131143
}
132144

133-
// Process asynchronously to avoid blocking the dispatcher loop.
134-
core.Go("icmp.responder", func() {
135-
r.process(h, nic, parsed, src, dst)
136-
})
145+
if useIcmpForwarder {
146+
wire.Pool.Put(parsed)
147+
return r.forward(h, pkt, src, dst)
148+
} else {
149+
// Process asynchronously to avoid blocking the dispatcher loop.
150+
core.Go("icmp.responder", func() {
151+
r.process(h, nic, parsed, src, dst)
152+
})
153+
}
137154

138155
return true
139156
}
140157

158+
func (r *icmpResponder) forward(h *icmpForwarder, pkt *stack.PacketBuffer, src, dst netip.AddrPort) bool {
159+
pkt.IncRef()
160+
defer pkt.DecRef()
161+
162+
var id stack.TransportEndpointID
163+
// local is dst / remote is src; see: netstack/icmp/icmp.go:func (h *icmpForwarder) reply4
164+
// and netstack/icmp/icmp.go:func (h *icmpForwarder) reply6
165+
id.LocalAddress = tcpip.AddrFrom16Slice(dst.Addr().AsSlice())
166+
id.RemoteAddress = tcpip.AddrFrom16Slice(src.Addr().AsSlice())
167+
// ICMP does not use ports, so they remain zero.
168+
id.LocalPort = 0
169+
id.RemotePort = 0
170+
171+
switch pkt.NetworkProtocolNumber {
172+
case header.IPv4ProtocolNumber:
173+
return h.reply4(id, pkt)
174+
case header.IPv6ProtocolNumber:
175+
return h.reply6(id, pkt)
176+
}
177+
178+
log.W("icmp: responder: unsupported proto: %d; %s => %s",
179+
pkt.NetworkProtocolNumber, src, dst)
180+
return false
181+
}
182+
141183
// process handles the ICMP echo request and injects the reply back into the TUN.
142184
// The parsed packet is released back to the pool after processing.
143-
func (r *icmpResponder) process(h GICMPHandler, nic tcpip.NICID, pkt *wire.Parsed, src, dst netip.AddrPort) {
185+
func (r *icmpResponder) process(h *icmpForwarder, nic tcpip.NICID, pkt *wire.Parsed, src, dst netip.AddrPort) {
144186
defer wire.Pool.Put(pkt)
145187

146188
icmpMsg := pkt.Transport()
@@ -156,7 +198,7 @@ func (r *icmpResponder) process(h GICMPHandler, nic tcpip.NICID, pkt *wire.Parse
156198
return
157199
}
158200

159-
pinged := h.Ping(icmpMsg, src, dst)
201+
pinged := h.h.Ping(icmpMsg, src, dst)
160202

161203
resp, proto, tag, err := r.echoReply(pkt, payload, pinged)
162204
notok = err != nil || len(resp) == 0

0 commit comments

Comments
 (0)