Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions chunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"io"
"net"
"reflect"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -221,20 +222,49 @@ func (this *Value) readBufferAt(chunk *Chunk, offset uint64) (uint64, error) {
return 0, errors.New("Unsuported type")
}

func protocolBufferFromValue(v interface{}) [][]byte {
func protocolBufferFromValue(v interface{}) ([][]byte, error) {
switch v := v.(type) {
case nil:
return protocolBufferFromNull()
case int, int8, int16, int32, int64:
return protocolBufferFromInt(v)
case float32, float64:
return protocolBufferFromFloat(v)
return protocolBufferFromNull(), nil
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return protocolBufferFromInt(v), nil
case float32:
return protocolBufferFromFloat(float64(v)), nil
case float64:
return protocolBufferFromFloat(v), nil
case string:
return protocolBufferFromString(v, true)
return protocolBufferFromString(v, true), nil
case []byte:
return protocolBufferFromBytes(v)
return protocolBufferFromBytes(v), nil
default:
return make([][]byte, 0)
rv := reflect.ValueOf(v)
if !rv.IsValid() {
return protocolBufferFromNull(), nil
}
if rv.Kind() == reflect.Pointer {
if rv.IsNil() {
return protocolBufferFromNull(), nil
}
return protocolBufferFromValue(rv.Elem().Interface())
}

switch rv.Kind() {
case reflect.String:
return protocolBufferFromString(rv.String(), true), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return protocolBufferFromInt(rv.Int()), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return protocolBufferFromInt(rv.Uint()), nil
case reflect.Float32, reflect.Float64:
return protocolBufferFromFloat(rv.Convert(reflect.TypeOf(float64(0))).Float()), nil
case reflect.Bool:
if rv.Bool() {
return protocolBufferFromInt(1), nil
}
return protocolBufferFromInt(0), nil
default:
return nil, fmt.Errorf("unsupported parameter type %T", v)
}
}
}

Expand Down Expand Up @@ -371,7 +401,11 @@ func (this *SQCloud) sendArray(command string, values []interface{}) (int, error
// convert values to buffers encoded with whe sqlitecloud protocol
buffers := [][]byte{protocolBufferFromString(command, true)[0]}
for _, v := range values {
buffers = append(buffers, protocolBufferFromValue(v)...)
valueBuffers, err := protocolBufferFromValue(v)
if err != nil {
return 0, err
}
buffers = append(buffers, valueBuffers...)
}

// calculate the array header
Expand Down
70 changes: 70 additions & 0 deletions chunk_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package sqlitecloud

import (
"fmt"
"strings"
"testing"
)

type testStringEnum string
type testIntEnum int

func TestProtocolBufferFromValueSupportsStringAlias(t *testing.T) {
val := testStringEnum("active")
buffers, err := protocolBufferFromValue(val)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if len(buffers) != 1 {
t.Fatalf("expected 1 buffer, got %d", len(buffers))
}
got := string(buffers[0])
want := fmt.Sprintf("%c%d %s\x00", CMD_ZEROSTRING, len("active")+1, "active")
if got != want {
t.Fatalf("unexpected encoded value: want %q got %q", want, got)
}
}

func TestProtocolBufferFromValueSupportsIntAliasPointer(t *testing.T) {
raw := testIntEnum(7)
buffers, err := protocolBufferFromValue(&raw)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if len(buffers) != 1 {
t.Fatalf("expected 1 buffer, got %d", len(buffers))
}
got := string(buffers[0])
want := fmt.Sprintf("%c%d ", CMD_INT, 7)
if got != want {
t.Fatalf("unexpected encoded value: want %q got %q", want, got)
}
}

func TestProtocolBufferFromValueSupportsFloat32(t *testing.T) {
buffers, err := protocolBufferFromValue(float32(2.5))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if len(buffers) != 1 {
t.Fatalf("expected 1 buffer, got %d", len(buffers))
}
got := string(buffers[0])
if !strings.HasPrefix(got, fmt.Sprintf("%c", CMD_FLOAT)) {
t.Fatalf("expected float buffer prefix, got %q", got)
}
}

func TestProtocolBufferFromValueUnsupportedTypeReturnsError(t *testing.T) {
type unsupported struct {
Name string
}

_, err := protocolBufferFromValue(unsupported{Name: "x"})
if err == nil {
t.Fatalf("expected error for unsupported type")
}
}