Skip to content

Commit 6296f31

Browse files
Merge pull request #20 from sqlitecloud/convert-custom-types
feat: support conversion of custom types
2 parents 96998e1 + c0c17b1 commit 6296f31

File tree

2 files changed

+163
-95
lines changed

2 files changed

+163
-95
lines changed

chunk.go

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -222,27 +222,49 @@ func (this *Value) readBufferAt(chunk *Chunk, offset uint64) (uint64, error) {
222222
return 0, errors.New("Unsuported type")
223223
}
224224

225-
func protocolBufferFromValue(v interface{}) [][]byte {
225+
func protocolBufferFromValue(v interface{}) ([][]byte, error) {
226226
switch v := v.(type) {
227227
case nil:
228-
return protocolBufferFromNull()
228+
return protocolBufferFromNull(), nil
229229
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
230-
return protocolBufferFromInt(v)
231-
case float32, float64:
232-
return protocolBufferFromFloat(v)
230+
return protocolBufferFromInt(v), nil
231+
case float32:
232+
return protocolBufferFromFloat(float64(v)), nil
233+
case float64:
234+
return protocolBufferFromFloat(v), nil
233235
case string:
234-
return protocolBufferFromString(v, true)
236+
return protocolBufferFromString(v, true), nil
235237
case []byte:
236-
return protocolBufferFromBytes(v)
238+
return protocolBufferFromBytes(v), nil
237239
default:
238240
rv := reflect.ValueOf(v)
239-
if rv.Kind() == reflect.Ptr {
241+
if !rv.IsValid() {
242+
return protocolBufferFromNull(), nil
243+
}
244+
if rv.Kind() == reflect.Pointer {
240245
if rv.IsNil() {
241-
return protocolBufferFromNull()
246+
return protocolBufferFromNull(), nil
242247
}
243248
return protocolBufferFromValue(rv.Elem().Interface())
244249
}
245-
return make([][]byte, 0)
250+
251+
switch rv.Kind() {
252+
case reflect.String:
253+
return protocolBufferFromString(rv.String(), true), nil
254+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
255+
return protocolBufferFromInt(rv.Int()), nil
256+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
257+
return protocolBufferFromInt(rv.Uint()), nil
258+
case reflect.Float32, reflect.Float64:
259+
return protocolBufferFromFloat(rv.Convert(reflect.TypeOf(float64(0))).Float()), nil
260+
case reflect.Bool:
261+
if rv.Bool() {
262+
return protocolBufferFromInt(1), nil
263+
}
264+
return protocolBufferFromInt(0), nil
265+
default:
266+
return nil, fmt.Errorf("unsupported parameter type %T", v)
267+
}
246268
}
247269
}
248270

@@ -263,14 +285,7 @@ func protocolBufferFromInt(v interface{}) [][]byte {
263285
}
264286

265287
func protocolBufferFromFloat(v interface{}) [][]byte {
266-
var f float64
267-
switch v := v.(type) {
268-
case float32:
269-
f = float64(v)
270-
case float64:
271-
f = v
272-
}
273-
return [][]byte{[]byte(fmt.Sprintf("%c%s ", CMD_FLOAT, strconv.FormatFloat(f, 'f', -1, 64)))}
288+
return [][]byte{[]byte(fmt.Sprintf("%c%s ", CMD_FLOAT, strconv.FormatFloat(v.(float64), 'f', -1, 64)))}
274289
}
275290

276291
// func protocolBufferFromFloat(v interface{}) [][]byte {
@@ -386,7 +401,11 @@ func (this *SQCloud) sendArray(command string, values []interface{}) (int, error
386401
// convert values to buffers encoded with whe sqlitecloud protocol
387402
buffers := [][]byte{protocolBufferFromString(command, true)[0]}
388403
for _, v := range values {
389-
buffers = append(buffers, protocolBufferFromValue(v)...)
404+
valueBuffers, err := protocolBufferFromValue(v)
405+
if err != nil {
406+
return 0, err
407+
}
408+
buffers = append(buffers, valueBuffers...)
390409
}
391410

392411
// calculate the array header

chunk_internal_test.go

Lines changed: 125 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,109 +1,158 @@
11
package sqlitecloud
22

33
import (
4+
"fmt"
5+
"strings"
46
"testing"
57
)
68

9+
type testStringEnum string
10+
type testIntEnum int
11+
712
func TestProtocolBufferFromValue(t *testing.T) {
13+
type unsupported struct{}
14+
intVal := 42
15+
strVal := "hello"
16+
817
tests := []struct {
9-
name string
10-
value interface{}
11-
wantLen int // expected number of []byte buffers returned
12-
wantType byte
18+
name string
19+
value interface{}
20+
wantLen int
21+
wantType byte
22+
wantError bool
1323
}{
14-
// Basic types
15-
{"nil", nil, 1, CMD_NULL},
16-
{"string", "hello", 1, CMD_ZEROSTRING},
17-
{"int", int(42), 1, CMD_INT},
18-
{"int8", int8(8), 1, CMD_INT},
19-
{"int16", int16(16), 1, CMD_INT},
20-
{"int32", int32(32), 1, CMD_INT},
21-
{"int64", int64(64), 1, CMD_INT},
22-
{"float32", float32(3.14), 1, CMD_FLOAT},
23-
{"float64", float64(2.71), 1, CMD_FLOAT},
24-
{"[]byte", []byte("blob"), 2, CMD_BLOB}, // header + data
25-
26-
// Unsigned integers
27-
{"uint", uint(1), 1, CMD_INT},
28-
{"uint8", uint8(1), 1, CMD_INT},
29-
{"uint16", uint16(1), 1, CMD_INT},
30-
{"uint32", uint32(1), 1, CMD_INT},
31-
{"uint64", uint64(1), 1, CMD_INT},
32-
33-
// Pointer types (dereferenced)
34-
{"*int", intPtr(42), 1, CMD_INT},
35-
{"*string", strPtr("hello"), 1, CMD_ZEROSTRING},
36-
{"*int nil", (*int)(nil), 1, CMD_NULL},
37-
{"*string nil", (*string)(nil), 1, CMD_NULL},
38-
39-
// Unsupported types still return empty buffers
40-
{"bool", true, 0, 0},
24+
{"nil", nil, 1, CMD_NULL, false},
25+
{"string", "hello", 1, CMD_ZEROSTRING, false},
26+
{"int", int(42), 1, CMD_INT, false},
27+
{"int8", int8(8), 1, CMD_INT, false},
28+
{"int16", int16(16), 1, CMD_INT, false},
29+
{"int32", int32(32), 1, CMD_INT, false},
30+
{"int64", int64(64), 1, CMD_INT, false},
31+
{"uint", uint(1), 1, CMD_INT, false},
32+
{"uint8", uint8(1), 1, CMD_INT, false},
33+
{"uint16", uint16(1), 1, CMD_INT, false},
34+
{"uint32", uint32(1), 1, CMD_INT, false},
35+
{"uint64", uint64(1), 1, CMD_INT, false},
36+
{"float32", float32(3.14), 1, CMD_FLOAT, false},
37+
{"float64", float64(2.71), 1, CMD_FLOAT, false},
38+
{"[]byte", []byte("blob"), 2, CMD_BLOB, false},
39+
{"bool true", true, 1, CMD_INT, false},
40+
{"bool false", false, 1, CMD_INT, false},
41+
{"*int", &intVal, 1, CMD_INT, false},
42+
{"*string", &strVal, 1, CMD_ZEROSTRING, false},
43+
{"*int nil", (*int)(nil), 1, CMD_NULL, false},
44+
{"*string nil", (*string)(nil), 1, CMD_NULL, false},
45+
{"unsupported", unsupported{}, 0, 0, true},
4146
}
4247

4348
for _, tt := range tests {
4449
t.Run(tt.name, func(t *testing.T) {
45-
buffers := protocolBufferFromValue(tt.value)
50+
buffers, err := protocolBufferFromValue(tt.value)
51+
if tt.wantError {
52+
if err == nil {
53+
t.Fatalf("expected error, got nil")
54+
}
55+
return
56+
}
57+
if err != nil {
58+
t.Fatalf("unexpected error: %v", err)
59+
}
4660
if len(buffers) != tt.wantLen {
47-
t.Errorf("protocolBufferFromValue(%T(%v)): got %d buffers, want %d", tt.value, tt.value, len(buffers), tt.wantLen)
61+
t.Fatalf("got %d buffers, want %d", len(buffers), tt.wantLen)
4862
}
49-
if tt.wantLen > 0 && len(buffers) > 0 {
50-
if buffers[0][0] != tt.wantType {
51-
t.Errorf("protocolBufferFromValue(%T(%v)): got type %c, want %c", tt.value, tt.value, buffers[0][0], tt.wantType)
52-
}
63+
if tt.wantLen > 0 && buffers[0][0] != tt.wantType {
64+
t.Fatalf("got first buffer type %q, want %q", buffers[0][0], tt.wantType)
5365
}
5466
})
5567
}
5668
}
5769

58-
func TestProtocolBufferFromValueMixedArray(t *testing.T) {
59-
// Simulates the loop in sendArray: builds buffers from a mixed values slice
60-
// and checks that the number of buffer groups matches the number of values.
61-
pInt := intPtr(99)
62-
values := []interface{}{
63-
"hello", // string -> 1 buffer
64-
int(42), // int -> 1 buffer
65-
nil, // nil -> 1 buffer
66-
pInt, // *int -> 1 buffer (dereferenced to int)
67-
float64(3), // float64 -> 1 buffer
68-
uint(7), // uint -> 1 buffer
69-
[]byte("x"), // []byte -> 2 buffers (header+data)
70+
func TestProtocolBufferFromValueSupportsStringAlias(t *testing.T) {
71+
val := testStringEnum("active")
72+
buffers, err := protocolBufferFromValue(val)
73+
if err != nil {
74+
t.Fatalf("unexpected error: %v", err)
7075
}
7176

72-
// Count how many values produce at least one buffer
73-
buffersPerValue := make([]int, len(values))
74-
totalBuffers := 0
75-
missingValues := 0
77+
if len(buffers) != 1 {
78+
t.Fatalf("expected 1 buffer, got %d", len(buffers))
79+
}
80+
got := string(buffers[0])
81+
want := fmt.Sprintf("%c%d %s\x00", CMD_ZEROSTRING, len("active")+1, "active")
82+
if got != want {
83+
t.Fatalf("unexpected encoded value: want %q got %q", want, got)
84+
}
85+
}
7686

77-
for i, v := range values {
78-
bufs := protocolBufferFromValue(v)
79-
buffersPerValue[i] = len(bufs)
80-
totalBuffers += len(bufs)
81-
if len(bufs) == 0 {
82-
missingValues++
83-
t.Errorf("value[%d] (%T = %v) produced 0 buffers — will be silently dropped", i, v, v)
84-
}
87+
func TestProtocolBufferFromValueSupportsIntAliasPointer(t *testing.T) {
88+
raw := testIntEnum(7)
89+
buffers, err := protocolBufferFromValue(&raw)
90+
if err != nil {
91+
t.Fatalf("unexpected error: %v", err)
8592
}
8693

87-
if missingValues > 0 {
88-
t.Errorf("%d out of %d values produced no buffers and will be missing from the protocol message", missingValues, len(values))
94+
if len(buffers) != 1 {
95+
t.Fatalf("expected 1 buffer, got %d", len(buffers))
8996
}
97+
got := string(buffers[0])
98+
want := fmt.Sprintf("%c%d ", CMD_INT, 7)
99+
if got != want {
100+
t.Fatalf("unexpected encoded value: want %q got %q", want, got)
101+
}
102+
}
90103

91-
// Reproduce the exact loop from sendArray
92-
buffers := [][]byte{}
93-
for _, v := range values {
94-
buffers = append(buffers, protocolBufferFromValue(v)...)
104+
func TestProtocolBufferFromValueSupportsFloat32(t *testing.T) {
105+
buffers, err := protocolBufferFromValue(float32(2.5))
106+
if err != nil {
107+
t.Fatalf("unexpected error: %v", err)
95108
}
96109

97-
t.Logf("values count: %d, total buffers: %d, buffers per value: %v", len(values), len(buffers), buffersPerValue)
110+
if len(buffers) != 1 {
111+
t.Fatalf("expected 1 buffer, got %d", len(buffers))
112+
}
113+
got := string(buffers[0])
114+
if !strings.HasPrefix(got, fmt.Sprintf("%c", CMD_FLOAT)) {
115+
t.Fatalf("expected float buffer prefix, got %q", got)
116+
}
117+
}
118+
119+
func TestProtocolBufferFromValueUnsupportedTypeReturnsError(t *testing.T) {
120+
type unsupported struct {
121+
Name string
122+
}
98123

99-
// Every value must produce at least 1 buffer ([]byte produces 2)
100-
expectedMinBuffers := len(values)
101-
if len(buffers) < expectedMinBuffers {
102-
t.Errorf("buffers array has %d elements, expected at least %d (one per value). %d values were silently dropped.",
103-
len(buffers), expectedMinBuffers, missingValues)
124+
_, err := protocolBufferFromValue(unsupported{Name: "x"})
125+
if err == nil {
126+
t.Fatalf("expected error for unsupported type")
104127
}
105128
}
106129

107-
// helpers
108-
func intPtr(v int) *int { return &v }
109-
func strPtr(v string) *string { return &v }
130+
func TestProtocolBufferFromValueMixedArrayNoSilentDrops(t *testing.T) {
131+
pInt := 99
132+
values := []interface{}{
133+
"hello",
134+
int(42),
135+
nil,
136+
&pInt,
137+
float64(3),
138+
uint(7),
139+
[]byte("x"),
140+
true,
141+
}
142+
143+
buffers := [][]byte{}
144+
for i, v := range values {
145+
valueBuffers, err := protocolBufferFromValue(v)
146+
if err != nil {
147+
t.Fatalf("unexpected error at index %d (%T): %v", i, v, err)
148+
}
149+
if len(valueBuffers) == 0 {
150+
t.Fatalf("value at index %d produced zero buffers", i)
151+
}
152+
buffers = append(buffers, valueBuffers...)
153+
}
154+
155+
if len(buffers) < len(values) {
156+
t.Fatalf("got %d total buffers, expected at least %d", len(buffers), len(values))
157+
}
158+
}

0 commit comments

Comments
 (0)