Skip to content

Commit b41f4cc

Browse files
committed
device: remove recursion from insertion and connect parent pointers
This makes the insertion algorithm a bit more efficient, while also now taking on the additional task of connecting up parent pointers. This will be handy in the following commit. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
1 parent 4a57024 commit b41f4cc

3 files changed

Lines changed: 95 additions & 59 deletions

File tree

device/allowedips.go

Lines changed: 77 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,15 @@ import (
1414
"unsafe"
1515
)
1616

17+
type parentIndirection struct {
18+
parentBit **trieEntry
19+
parentBitType uint8
20+
}
21+
1722
type trieEntry struct {
1823
peer *Peer
1924
child [2]*trieEntry
25+
parent parentIndirection
2026
cidr uint8
2127
bitAtByte uint8
2228
bitAtShift uint8
@@ -114,78 +120,107 @@ func (node *trieEntry) maskSelf() {
114120
}
115121
}
116122

117-
func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
118-
119-
// at leaf
123+
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
124+
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
125+
parent = node
126+
if parent.cidr == cidr {
127+
exact = true
128+
return
129+
}
130+
bit := node.choose(ip)
131+
node = node.child[bit]
132+
}
133+
return
134+
}
120135

121-
if node == nil {
136+
func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
137+
if *trie.parentBit == nil {
122138
node := &trieEntry{
123-
bits: ip,
124139
peer: peer,
140+
parent: trie,
141+
bits: ip,
125142
cidr: cidr,
126143
bitAtByte: cidr / 8,
127144
bitAtShift: 7 - (cidr % 8),
128145
}
129146
node.maskSelf()
130147
node.addToPeerEntries()
131-
return node
148+
*trie.parentBit = node
149+
return
132150
}
133-
134-
// traverse deeper
135-
136-
common := commonBits(node.bits, ip)
137-
if node.cidr <= cidr && common >= node.cidr {
138-
if node.cidr == cidr {
139-
node.removeFromPeerEntries()
140-
node.peer = peer
141-
node.addToPeerEntries()
142-
return node
143-
}
144-
bit := node.choose(ip)
145-
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
146-
return node
151+
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
152+
if exact {
153+
node.removeFromPeerEntries()
154+
node.peer = peer
155+
node.addToPeerEntries()
156+
return
147157
}
148158

149-
// split node
150-
151159
newNode := &trieEntry{
152-
bits: ip,
153160
peer: peer,
161+
bits: ip,
154162
cidr: cidr,
155163
bitAtByte: cidr / 8,
156164
bitAtShift: 7 - (cidr % 8),
157165
}
158166
newNode.maskSelf()
159167
newNode.addToPeerEntries()
160168

169+
var down *trieEntry
170+
if node == nil {
171+
down = *trie.parentBit
172+
} else {
173+
bit := node.choose(ip)
174+
down = node.child[bit]
175+
if down == nil {
176+
newNode.parent = parentIndirection{&node.child[bit], bit}
177+
node.child[bit] = newNode
178+
return
179+
}
180+
}
181+
common := commonBits(down.bits, ip)
161182
if common < cidr {
162183
cidr = common
163184
}
164-
165-
// check for shorter prefix
185+
parent := node
166186

167187
if newNode.cidr == cidr {
168-
bit := newNode.choose(node.bits)
169-
newNode.child[bit] = node
170-
return newNode
188+
bit := newNode.choose(down.bits)
189+
down.parent = parentIndirection{&newNode.child[bit], bit}
190+
newNode.child[bit] = down
191+
if parent == nil {
192+
newNode.parent = trie
193+
*trie.parentBit = newNode
194+
} else {
195+
bit := parent.choose(newNode.bits)
196+
newNode.parent = parentIndirection{&parent.child[bit], bit}
197+
parent.child[bit] = newNode
198+
}
199+
return
171200
}
172201

173-
// create new parent for node & newNode
174-
175-
parent := &trieEntry{
176-
bits: append([]byte{}, ip...),
177-
peer: nil,
202+
node = &trieEntry{
203+
bits: append([]byte{}, newNode.bits...),
178204
cidr: cidr,
179205
bitAtByte: cidr / 8,
180206
bitAtShift: 7 - (cidr % 8),
181207
}
182-
parent.maskSelf()
183-
184-
bit := parent.choose(ip)
185-
parent.child[bit] = newNode
186-
parent.child[bit^1] = node
187-
188-
return parent
208+
node.maskSelf()
209+
210+
bit := node.choose(down.bits)
211+
down.parent = parentIndirection{&node.child[bit], bit}
212+
node.child[bit] = down
213+
bit = node.choose(newNode.bits)
214+
newNode.parent = parentIndirection{&node.child[bit], bit}
215+
node.child[bit] = newNode
216+
if parent == nil {
217+
node.parent = trie
218+
*trie.parentBit = node
219+
} else {
220+
bit := parent.choose(node.bits)
221+
node.parent = parentIndirection{&parent.child[bit], bit}
222+
parent.child[bit] = node
223+
}
189224
}
190225

191226
func (node *trieEntry) lookup(ip net.IP) *Peer {
@@ -236,9 +271,9 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
236271

237272
switch len(ip) {
238273
case net.IPv6len:
239-
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
274+
parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
240275
case net.IPv4len:
241-
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
276+
parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
242277
default:
243278
panic(errors.New("inserting unknown address type"))
244279
}

device/allowedips_rand_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
6565
}
6666

6767
func TestTrieRandomIPv4(t *testing.T) {
68-
var trie *trieEntry
6968
var slow SlowRouter
7069
var peers []*Peer
70+
var allowedIPs AllowedIPs
7171

7272
rand.Seed(1)
7373

@@ -82,25 +82,25 @@ func TestTrieRandomIPv4(t *testing.T) {
8282
rand.Read(addr[:])
8383
cidr := uint8(rand.Uint32() % (AddressLength * 8))
8484
index := rand.Int() % NumberOfPeers
85-
trie = trie.insert(addr[:], cidr, peers[index])
85+
allowedIPs.Insert(addr[:], cidr, peers[index])
8686
slow = slow.Insert(addr[:], cidr, peers[index])
8787
}
8888

8989
for n := 0; n < NumberOfTests; n++ {
9090
var addr [AddressLength]byte
9191
rand.Read(addr[:])
9292
peer1 := slow.Lookup(addr[:])
93-
peer2 := trie.lookup(addr[:])
93+
peer2 := allowedIPs.LookupIPv4(addr[:])
9494
if peer1 != peer2 {
9595
t.Error("Trie did not match naive implementation, for:", addr)
9696
}
9797
}
9898
}
9999

100100
func TestTrieRandomIPv6(t *testing.T) {
101-
var trie *trieEntry
102101
var slow SlowRouter
103102
var peers []*Peer
103+
var allowedIPs AllowedIPs
104104

105105
rand.Seed(1)
106106

@@ -115,15 +115,15 @@ func TestTrieRandomIPv6(t *testing.T) {
115115
rand.Read(addr[:])
116116
cidr := uint8(rand.Uint32() % (AddressLength * 8))
117117
index := rand.Int() % NumberOfPeers
118-
trie = trie.insert(addr[:], cidr, peers[index])
118+
allowedIPs.Insert(addr[:], cidr, peers[index])
119119
slow = slow.Insert(addr[:], cidr, peers[index])
120120
}
121121

122122
for n := 0; n < NumberOfTests; n++ {
123123
var addr [AddressLength]byte
124124
rand.Read(addr[:])
125125
peer1 := slow.Lookup(addr[:])
126-
peer2 := trie.lookup(addr[:])
126+
peer2 := allowedIPs.LookupIPv6(addr[:])
127127
if peer1 != peer2 {
128128
t.Error("Trie did not match naive implementation, for:", addr)
129129
}

device/allowedips_test.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ func TestCommonBits(t *testing.T) {
4242
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
4343
var trie *trieEntry
4444
var peers []*Peer
45+
root := parentIndirection{&trie, 2}
4546

4647
rand.Seed(1)
4748

@@ -56,7 +57,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
5657
rand.Read(addr[:])
5758
cidr := uint8(rand.Uint32() % (AddressLength * 8))
5859
index := rand.Int() % peerNumber
59-
trie = trie.insert(addr[:], cidr, peers[index])
60+
root.insert(addr[:], cidr, peers[index])
6061
}
6162

6263
for n := 0; n < b.N; n++ {
@@ -94,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
9495
g := &Peer{}
9596
h := &Peer{}
9697

97-
var trie *trieEntry
98+
var allowedIPs AllowedIPs
9899

99100
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
100-
trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
101+
allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
101102
}
102103

103104
assertEQ := func(peer *Peer, a, b, c, d byte) {
104-
p := trie.lookup([]byte{a, b, c, d})
105+
p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
105106
if p != peer {
106107
t.Error("Assert EQ failed")
107108
}
108109
}
109110

110111
assertNEQ := func(peer *Peer, a, b, c, d byte) {
111-
p := trie.lookup([]byte{a, b, c, d})
112+
p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
112113
if p == peer {
113114
t.Error("Assert NEQ failed")
114115
}
@@ -150,20 +151,20 @@ func TestTrieIPv4(t *testing.T) {
150151
assertEQ(a, 192, 0, 0, 0)
151152
assertEQ(a, 255, 0, 0, 0)
152153

153-
trie = trie.removeByPeer(a)
154+
allowedIPs.RemoveByPeer(a)
154155

155156
assertNEQ(a, 1, 0, 0, 0)
156157
assertNEQ(a, 64, 0, 0, 0)
157158
assertNEQ(a, 128, 0, 0, 0)
158159
assertNEQ(a, 192, 0, 0, 0)
159160
assertNEQ(a, 255, 0, 0, 0)
160161

161-
trie = nil
162+
allowedIPs = AllowedIPs{}
162163

163164
insert(a, 192, 168, 0, 0, 16)
164165
insert(a, 192, 168, 0, 0, 24)
165166

166-
trie = trie.removeByPeer(a)
167+
allowedIPs.RemoveByPeer(a)
167168

168169
assertNEQ(a, 192, 168, 0, 1)
169170
}
@@ -181,7 +182,7 @@ func TestTrieIPv6(t *testing.T) {
181182
g := &Peer{}
182183
h := &Peer{}
183184

184-
var trie *trieEntry
185+
var allowedIPs AllowedIPs
185186

186187
expand := func(a uint32) []byte {
187188
var out [4]byte
@@ -198,7 +199,7 @@ func TestTrieIPv6(t *testing.T) {
198199
addr = append(addr, expand(b)...)
199200
addr = append(addr, expand(c)...)
200201
addr = append(addr, expand(d)...)
201-
trie = trie.insert(addr, cidr, peer)
202+
allowedIPs.Insert(addr, cidr, peer)
202203
}
203204

204205
assertEQ := func(peer *Peer, a, b, c, d uint32) {
@@ -207,7 +208,7 @@ func TestTrieIPv6(t *testing.T) {
207208
addr = append(addr, expand(b)...)
208209
addr = append(addr, expand(c)...)
209210
addr = append(addr, expand(d)...)
210-
p := trie.lookup(addr)
211+
p := allowedIPs.LookupIPv6(addr)
211212
if p != peer {
212213
t.Error("Assert EQ failed")
213214
}

0 commit comments

Comments
 (0)