Skip to content

Commit 8a6f3ab

Browse files
committed
concurrent linked queue
1 parent 6bd8926 commit 8a6f3ab

2 files changed

Lines changed: 280 additions & 0 deletions

File tree

concurrent/queue/linked_queue.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright 2023 igevin
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package queue
16+
17+
import (
18+
"github.com/igevin/algokit/collection/queue"
19+
"sync/atomic"
20+
"unsafe"
21+
)
22+
23+
// ConcurrentLinkedQueue 无界并发安全队列
24+
type ConcurrentLinkedQueue[T any] struct {
25+
// *node[T]
26+
head unsafe.Pointer
27+
// *node[T]
28+
tail unsafe.Pointer
29+
}
30+
31+
func NewConcurrentLinkedQueue[T any]() *ConcurrentLinkedQueue[T] {
32+
head := &node[T]{}
33+
ptr := unsafe.Pointer(head)
34+
return &ConcurrentLinkedQueue[T]{
35+
head: ptr,
36+
tail: ptr,
37+
}
38+
}
39+
40+
func (c *ConcurrentLinkedQueue[T]) Enqueue(t T) error {
41+
newNode := &node[T]{val: t}
42+
newPtr := unsafe.Pointer(newNode)
43+
for {
44+
tailPtr := atomic.LoadPointer(&c.tail)
45+
tail := (*node[T])(tailPtr)
46+
tailNext := atomic.LoadPointer(&tail.next)
47+
if tailNext != nil {
48+
// 已经被人修改了,我们不需要修复,因为预期中修改的那个人会把 c.tail 指过去
49+
continue
50+
}
51+
if atomic.CompareAndSwapPointer(&tail.next, tailNext, newPtr) {
52+
// 如果失败也不用担心,说明有人抢先一步了
53+
atomic.CompareAndSwapPointer(&c.tail, tailPtr, newPtr)
54+
return nil
55+
}
56+
}
57+
}
58+
59+
func (c *ConcurrentLinkedQueue[T]) Dequeue() (T, error) {
60+
for {
61+
headPtr := atomic.LoadPointer(&c.head)
62+
head := (*node[T])(headPtr)
63+
tailPtr := atomic.LoadPointer(&c.tail)
64+
tail := (*node[T])(tailPtr)
65+
if head == tail {
66+
// 不需要做更多检测,在当下这一刻,我们就认为没有元素,即便这时候正好有人入队
67+
// 但是并不妨碍我们在它彻底入队完成——即所有的指针都调整好——之前,
68+
// 认为其实还是没有元素
69+
var t T
70+
return t, queue.ErrEmptyQueue
71+
}
72+
headNextPtr := atomic.LoadPointer(&head.next)
73+
if atomic.CompareAndSwapPointer(&c.head, headPtr, headNextPtr) {
74+
headNext := (*node[T])(headNextPtr)
75+
return headNext.val, nil
76+
}
77+
}
78+
}
79+
80+
type node[T any] struct {
81+
val T
82+
// *node[T]
83+
next unsafe.Pointer
84+
}
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
// Copyright 2023 igevin
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package queue
16+
17+
import (
18+
"fmt"
19+
"github.com/stretchr/testify/assert"
20+
"github.com/stretchr/testify/require"
21+
"math/rand"
22+
"sync"
23+
"sync/atomic"
24+
"testing"
25+
)
26+
27+
func TestConcurrentQueue_Enqueue(t *testing.T) {
28+
t.Parallel()
29+
testCases := []struct {
30+
name string
31+
q func() *ConcurrentLinkedQueue[int]
32+
val int
33+
34+
wantData []int
35+
wantErr error
36+
}{
37+
{
38+
name: "empty",
39+
q: func() *ConcurrentLinkedQueue[int] {
40+
return NewConcurrentLinkedQueue[int]()
41+
},
42+
val: 123,
43+
wantData: []int{123},
44+
},
45+
{
46+
name: "multiple",
47+
q: func() *ConcurrentLinkedQueue[int] {
48+
q := NewConcurrentLinkedQueue[int]()
49+
err := q.Enqueue(123)
50+
require.NoError(t, err)
51+
return q
52+
},
53+
val: 234,
54+
wantData: []int{123, 234},
55+
},
56+
}
57+
58+
for _, tc := range testCases {
59+
t.Run(tc.name, func(t *testing.T) {
60+
q := tc.q()
61+
err := q.Enqueue(tc.val)
62+
assert.Equal(t, tc.wantErr, err)
63+
assert.Equal(t, tc.wantData, q.asSlice())
64+
})
65+
}
66+
}
67+
68+
func TestConcurrentQueue_Dequeue(t *testing.T) {
69+
t.Parallel()
70+
testCases := []struct {
71+
name string
72+
q func() *ConcurrentLinkedQueue[int]
73+
wantVal int
74+
wantData []int
75+
wantErr error
76+
}{
77+
{
78+
name: "empty",
79+
q: func() *ConcurrentLinkedQueue[int] {
80+
q := NewConcurrentLinkedQueue[int]()
81+
return q
82+
},
83+
wantErr: errEmptyQueue,
84+
},
85+
{
86+
name: "single",
87+
q: func() *ConcurrentLinkedQueue[int] {
88+
q := NewConcurrentLinkedQueue[int]()
89+
err := q.Enqueue(123)
90+
assert.NoError(t, err)
91+
return q
92+
},
93+
wantVal: 123,
94+
},
95+
{
96+
name: "multiple",
97+
q: func() *ConcurrentLinkedQueue[int] {
98+
q := NewConcurrentLinkedQueue[int]()
99+
err := q.Enqueue(123)
100+
assert.NoError(t, err)
101+
err = q.Enqueue(234)
102+
assert.NoError(t, err)
103+
return q
104+
},
105+
wantVal: 123,
106+
wantData: []int{234},
107+
},
108+
{
109+
name: "enqueue and dequeue",
110+
q: func() *ConcurrentLinkedQueue[int] {
111+
q := NewConcurrentLinkedQueue[int]()
112+
err := q.Enqueue(123)
113+
assert.NoError(t, err)
114+
err = q.Enqueue(234)
115+
assert.NoError(t, err)
116+
val, err := q.Dequeue()
117+
assert.Equal(t, 123, val)
118+
assert.NoError(t, err)
119+
err = q.Enqueue(345)
120+
assert.NoError(t, err)
121+
return q
122+
},
123+
wantVal: 234,
124+
wantData: []int{345},
125+
},
126+
}
127+
128+
for _, tc := range testCases {
129+
t.Run(tc.name, func(t *testing.T) {
130+
q := tc.q()
131+
val, err := q.Dequeue()
132+
assert.Equal(t, tc.wantErr, err)
133+
if err != nil {
134+
return
135+
}
136+
assert.Equal(t, tc.wantVal, val)
137+
assert.Equal(t, tc.wantData, q.asSlice())
138+
})
139+
}
140+
}
141+
142+
func TestConcurrentLinkedQueue(t *testing.T) {
143+
t.Parallel()
144+
// 仅仅是为了测试在入队出队期间不会出现 panic 或者死循环之类的问题
145+
// FIFO 特性参考其余测试
146+
q := NewConcurrentLinkedQueue[int]()
147+
var wg sync.WaitGroup
148+
wg.Add(10000)
149+
for i := 0; i < 10; i++ {
150+
go func() {
151+
for j := 0; j < 1000; j++ {
152+
val := rand.Int()
153+
_ = q.Enqueue(val)
154+
}
155+
}()
156+
}
157+
var cnt int32 = 0
158+
for i := 0; i < 10; i++ {
159+
go func() {
160+
for {
161+
if atomic.LoadInt32(&cnt) >= 10000 {
162+
return
163+
}
164+
_, err := q.Dequeue()
165+
if err == nil {
166+
atomic.AddInt32(&cnt, 1)
167+
wg.Done()
168+
}
169+
}
170+
}()
171+
}
172+
wg.Wait()
173+
}
174+
175+
func (c *ConcurrentLinkedQueue[T]) asSlice() []T {
176+
var res []T
177+
cur := (*node[T])((*node[T])(c.head).next)
178+
for cur != nil {
179+
res = append(res, cur.val)
180+
cur = (*node[T])(cur.next)
181+
}
182+
return res
183+
}
184+
185+
func ExampleNewConcurrentLinkedQueue() {
186+
q := NewConcurrentLinkedQueue[int]()
187+
_ = q.Enqueue(10)
188+
val, err := q.Dequeue()
189+
if err != nil {
190+
// 一般意味着队列为空
191+
fmt.Println(err)
192+
}
193+
fmt.Println(val)
194+
// Output:
195+
// 10
196+
}

0 commit comments

Comments
 (0)