Skip to content

Commit 1fc910a

Browse files
fix(reader): 读取数据时的边界检查和错误处理 (#148) (#149)
* Update README.md * fix(reader): 修复读取数据时的边界检查和错误处理 在Reader的多个方法中添加了对缓冲区边界和读取错误的检查,确保在数据不足或读取失败时返回安全值。同时,新增了测试文件reader_test.go,验证了空缓冲区、不完整数据、错误Reader等情况下的行为。 * perf(reader): 使用unsafe.Slice优化ReadU8性能 通过使用unsafe.Slice直接读取数据到变量中,避免了额外的内存分配和拷贝操作,从而提高了ReadU8函数的性能。 * refactor(reader): 简化ReadU8函数中的错误处理逻辑
1 parent 92282d2 commit 1fc910a

2 files changed

Lines changed: 197 additions & 3 deletions

File tree

utils/binary/reader.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,15 @@ func (r *Reader) ReadAll() []byte {
8181

8282
func (r *Reader) ReadU8() (v uint8) {
8383
if r.reader != nil {
84-
_, _ = r.reader.Read(unsafe.Slice(&v, 1))
84+
_, err := r.reader.Read(unsafe.Slice(&v, 1))
85+
if err != nil {
86+
return 0
87+
}
8588
return
8689
}
90+
if r.pos >= len(r.buffer) {
91+
return 0
92+
}
8793
v = r.buffer[r.pos]
8894
r.pos++
8995
return
@@ -93,8 +99,16 @@ func readint[T ~uint16 | ~uint32 | ~uint64](r *Reader) (v T) {
9399
sz := unsafe.Sizeof(v)
94100
buf := make([]byte, 8)
95101
if r.reader != nil {
96-
_, _ = r.reader.Read(buf[8-sz:])
102+
n, err := r.reader.Read(buf[8-sz:])
103+
if err != nil || n < int(sz) {
104+
// 读取失败或读取的数据不足,返回零值
105+
return 0
106+
}
97107
} else {
108+
// 确保缓冲区有足够的数据
109+
if r.pos+int(sz) > len(r.buffer) {
110+
return 0
111+
}
98112
copy(buf[8-sz:], r.buffer[r.pos:r.pos+int(sz)])
99113
r.pos += int(sz)
100114
}
@@ -129,6 +143,10 @@ func (r *Reader) ReadBytesNoCopy(length int) (v []byte) {
129143
if r.reader != nil {
130144
return r.ReadBytes(length)
131145
}
146+
// 确保缓冲区有足够的数据
147+
if r.pos+length > len(r.buffer) {
148+
return make([]byte, 0)
149+
}
132150
v = r.buffer[r.pos : r.pos+length]
133151
r.pos += length
134152
return
@@ -138,8 +156,16 @@ func (r *Reader) ReadBytes(length int) (v []byte) {
138156
// 返回一个全新的数组罢
139157
v = make([]byte, length)
140158
if r.reader != nil {
141-
_, _ = r.reader.Read(v)
159+
n, err := io.ReadFull(r.reader, v)
160+
if err != nil || n < length {
161+
// 读取失败或读取的数据不足,返回空数组
162+
return make([]byte, 0)
163+
}
142164
} else {
165+
// 确保缓冲区有足够的数据
166+
if r.pos+length > len(r.buffer) {
167+
return make([]byte, 0)
168+
}
143169
copy(v, r.buffer[r.pos:r.pos+length])
144170
r.pos += length
145171
}

utils/binary/reader_test.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package binary
2+
3+
import (
4+
"io"
5+
"testing"
6+
)
7+
8+
// TestReaderEmptyBuffer 测试空缓冲区的情况
9+
func TestReaderEmptyBuffer(t *testing.T) {
10+
// 测试空缓冲区
11+
r := NewReader([]byte{})
12+
13+
// 测试ReadU8
14+
if v := r.ReadU8(); v != 0 {
15+
t.Errorf("ReadU8 with empty buffer should return 0, got %d", v)
16+
}
17+
18+
// 测试ReadU16
19+
if v := r.ReadU16(); v != 0 {
20+
t.Errorf("ReadU16 with empty buffer should return 0, got %d", v)
21+
}
22+
23+
// 测试ReadU32
24+
if v := r.ReadU32(); v != 0 {
25+
t.Errorf("ReadU32 with empty buffer should return 0, got %d", v)
26+
}
27+
28+
// 测试ReadU64
29+
if v := r.ReadU64(); v != 0 {
30+
t.Errorf("ReadU64 with empty buffer should return 0, got %d", v)
31+
}
32+
33+
// 测试ReadBytes
34+
if bytes := r.ReadBytes(10); len(bytes) != 0 {
35+
t.Errorf("ReadBytes with empty buffer should return empty slice, got %v", bytes)
36+
}
37+
38+
// 测试ReadBytesNoCopy
39+
if bytes := r.ReadBytesNoCopy(10); len(bytes) != 0 {
40+
t.Errorf("ReadBytesNoCopy with empty buffer should return empty slice, got %v", bytes)
41+
}
42+
}
43+
44+
// TestReaderIncompleteData 测试不完整数据的情况
45+
func TestReaderIncompleteData(t *testing.T) {
46+
// 测试不完整数据 - 只有1个字节
47+
r := NewReader([]byte{0x01})
48+
49+
// 测试ReadU16 (需要2字节)
50+
if v := r.ReadU16(); v != 0 {
51+
t.Errorf("ReadU16 with incomplete data should return 0, got %d", v)
52+
}
53+
54+
// 重置Reader
55+
r = NewReader([]byte{0x01, 0x02})
56+
57+
// 测试ReadU32 (需要4字节)
58+
if v := r.ReadU32(); v != 0 {
59+
t.Errorf("ReadU32 with incomplete data should return 0, got %d", v)
60+
}
61+
62+
// 测试ReadBytes超出可用长度
63+
r = NewReader([]byte{0x01, 0x02, 0x03})
64+
if bytes := r.ReadBytes(10); len(bytes) != 0 {
65+
t.Errorf("ReadBytes with insufficient data should return empty slice, got %v", bytes)
66+
}
67+
}
68+
69+
// TestReaderWithIOReader 测试使用io.Reader的情况
70+
func TestReaderWithIOReader(t *testing.T) {
71+
// 创建一个会返回错误的Reader
72+
errReader := &errorReader{}
73+
r := ParseReader(errReader)
74+
75+
// 测试ReadU8
76+
if v := r.ReadU8(); v != 0 {
77+
t.Errorf("ReadU8 with error reader should return 0, got %d", v)
78+
}
79+
80+
// 测试ReadU16
81+
if v := r.ReadU16(); v != 0 {
82+
t.Errorf("ReadU16 with error reader should return 0, got %d", v)
83+
}
84+
85+
// 测试ReadU32
86+
if v := r.ReadU32(); v != 0 {
87+
t.Errorf("ReadU32 with error reader should return 0, got %d", v)
88+
}
89+
90+
// 测试ReadBytes
91+
if bytes := r.ReadBytes(10); len(bytes) != 0 {
92+
t.Errorf("ReadBytes with error reader should return empty slice, got %v", bytes)
93+
}
94+
95+
// 测试ReadAll
96+
if data := r.ReadAll(); data != nil {
97+
t.Errorf("ReadAll with error reader should return nil, got %v", data)
98+
}
99+
}
100+
101+
// TestReaderNormalData 测试正常数据的情况
102+
func TestReaderNormalData(t *testing.T) {
103+
// 准备测试数据
104+
data := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}
105+
r := NewReader(data)
106+
107+
// 测试ReadU8
108+
if v := r.ReadU8(); v != 0x01 {
109+
t.Errorf("ReadU8 should return 0x01, got 0x%02x", v)
110+
}
111+
112+
// 测试ReadU16
113+
if v := r.ReadU16(); v != 0x0203 {
114+
t.Errorf("ReadU16 should return 0x0203, got 0x%04x", v)
115+
}
116+
117+
// 测试ReadU32
118+
if v := r.ReadU32(); v != 0x04050607 {
119+
t.Errorf("ReadU32 should return 0x04050607, got 0x%08x", v)
120+
}
121+
122+
// 测试ReadByte
123+
if b, err := r.ReadByte(); err != nil || b != 0x08 {
124+
t.Errorf("ReadByte should return 0x08, got 0x%02x, err: %v", b, err)
125+
}
126+
127+
// 测试读取完所有数据后的ReadByte
128+
if _, err := r.ReadByte(); err != io.EOF {
129+
t.Errorf("ReadByte after end should return EOF, got %v", err)
130+
}
131+
}
132+
133+
// TestReaderShortRead 测试短读的情况
134+
func TestReaderShortRead(t *testing.T) {
135+
// 创建一个会返回短读的Reader
136+
shortReader := &shortReader{data: []byte{0x01, 0x02, 0x03, 0x04}}
137+
r := ParseReader(shortReader)
138+
139+
// 测试ReadBytes
140+
if bytes := r.ReadBytes(10); len(bytes) != 0 {
141+
t.Errorf("ReadBytes with short reader should return empty slice, got %v", bytes)
142+
}
143+
}
144+
145+
// 辅助测试结构
146+
147+
// errorReader 总是返回错误的Reader
148+
type errorReader struct{}
149+
150+
func (r *errorReader) Read(p []byte) (n int, err error) {
151+
return 0, io.ErrUnexpectedEOF
152+
}
153+
154+
// shortReader 总是返回短读的Reader
155+
type shortReader struct {
156+
data []byte
157+
pos int
158+
}
159+
160+
func (r *shortReader) Read(p []byte) (n int, err error) {
161+
if r.pos >= len(r.data) {
162+
return 0, io.EOF
163+
}
164+
// 只读取一个字节,模拟短读
165+
p[0] = r.data[r.pos]
166+
r.pos++
167+
return 1, nil
168+
}

0 commit comments

Comments
 (0)