Skip to content

Commit f7f59da

Browse files
committed
feat: Add support for repeated scalar fields as array columns in RisingWave
1 parent 67b186a commit f7f59da

6 files changed

Lines changed: 108 additions & 68 deletions

File tree

db_proto/sql/postgres/types_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func TestValueToString(t *testing.T) {
2828
{"int", int(789), "789"},
2929

3030
// Unsigned integer values
31-
{"uint64", uint64(123), "'123'"},
31+
{"uint64", uint64(123), "123"},
3232
{"uint32", uint32(456), "456"},
3333
{"uint", uint(789), "789"},
3434

db_proto/sql/risingwave/accumulator_inserter.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,16 @@ func createInsertFromDescriptorAcc(table *schema.Table, dialect sql2.Dialect) (s
8484
continue
8585
}
8686

87-
if field.IsRepeated || field.IsExtension {
87+
if field.IsExtension { // not a direct child
8888
continue
8989
}
90+
if field.IsRepeated {
91+
// Skip repeated messages, but allow repeated scalars (arrays)
92+
if field.IsMessage {
93+
continue
94+
}
95+
// Allow repeated scalar fields to be processed as arrays
96+
}
9097
fieldNames = append(fieldNames, field.QuotedName())
9198
}
9299

db_proto/sql/risingwave/dialect.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,12 @@ func (d *DialectRisingwave) createTable(table *schema.Table) error {
136136

137137
fieldQuotedName := f.QuotedName()
138138

139-
// Skip repeated fields (not supported in SQL)
139+
// Allow repeated scalar fields as arrays; skip repeated messages
140140
if f.IsRepeated {
141-
continue
141+
if f.IsMessage {
142+
continue
143+
}
144+
// Repeated scalars proceed and are typed as arrays by MapFieldType
142145
}
143146

144147
// Skip message fields that don't map to simple columns

db_proto/sql/risingwave/dialect_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ func TestDialectRisingwave_CreateTable_WithUniqueConstraint(t *testing.T) {
229229
assert.Contains(t, sql, `"name" CHARACTER VARYING`)
230230
}
231231

232-
func TestDialectRisingwave_CreateTable_SkipsRepeatedFields(t *testing.T) {
232+
func TestDialectRisingwave_CreateTable_HandlesRepeatedScalarsAsArrays(t *testing.T) {
233233
logger := zap.NewNop()
234234

235235
tagsField := createMockFieldDescriptor("tags", descriptor.FieldDescriptorProto_TYPE_STRING)
@@ -247,9 +247,10 @@ func TestDialectRisingwave_CreateTable_SkipsRepeatedFields(t *testing.T) {
247247
d, err := NewDialectRisingwave("public", tableRegistry, logger)
248248
require.NoError(t, err)
249249

250-
sql := d.CreateTableSql["users"]
251-
assert.NotContains(t, sql, "tags")
252-
assert.Contains(t, sql, `"name" CHARACTER VARYING`)
250+
sql := d.CreateTableSql["users"]
251+
// Repeated scalar field should be present as an array type
252+
assert.Contains(t, sql, `"tags" CHARACTER VARYING[]`)
253+
assert.Contains(t, sql, `"name" CHARACTER VARYING`)
253254
}
254255

255256
func TestDialectRisingwave_CreateTable_PreventsDuplicateColumns(t *testing.T) {

db_proto/sql/risingwave/row_inserter.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,16 @@ func createInsertFromDescriptor(table *schema.Table, dialect sql2.Dialect) (stri
110110
if field.Name == returningField {
111111
continue
112112
}
113-
if field.IsRepeated || field.IsExtension {
113+
if field.IsExtension { // not a direct child
114114
continue
115115
}
116+
if field.IsRepeated {
117+
// Skip repeated messages, but allow repeated scalars (arrays)
118+
if field.IsMessage {
119+
continue
120+
}
121+
// Allow repeated scalar fields to be processed as arrays
122+
}
116123
fieldCount++
117124
fieldNames = append(fieldNames, field.QuotedName())
118125
placeholders = append(placeholders, fmt.Sprintf("$%d", fieldCount))
@@ -138,6 +145,13 @@ func (i *RowInserter) insert(table string, values []any, database *Database) err
138145
values[i] = base64.StdEncoding.EncodeToString(v)
139146
case *timestamppb.Timestamp:
140147
values[i] = "'" + v.AsTime().Format(time.RFC3339) + "'"
148+
case []interface{}:
149+
// Convert to PostgreSQL/RisingWave array literal: {elem1,elem2,...}
150+
var elements []string
151+
for _, elem := range v {
152+
elements = append(elements, ValueToString(elem))
153+
}
154+
values[i] = "{" + strings.Join(elements, ",") + "}"
141155
}
142156
}
143157

db_proto/sql/risingwave/types.go

Lines changed: 74 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
package risingwave
22

33
import (
4-
"encoding/hex"
5-
"fmt"
6-
"strconv"
7-
"strings"
8-
"time"
4+
"encoding/hex"
5+
"fmt"
6+
"strconv"
7+
"strings"
8+
"time"
99

10-
"github.com/golang/protobuf/protoc-gen-go/descriptor"
11-
"github.com/jhump/protoreflect/desc"
12-
sql2 "github.com/streamingfast/substreams-sink-sql/db_proto/sql"
13-
"github.com/streamingfast/substreams-sink-sql/proto"
14-
"google.golang.org/protobuf/types/known/timestamppb"
10+
"github.com/golang/protobuf/protoc-gen-go/descriptor"
11+
"github.com/jhump/protoreflect/desc"
12+
sql2 "github.com/streamingfast/substreams-sink-sql/db_proto/sql"
13+
"github.com/streamingfast/substreams-sink-sql/proto"
14+
"google.golang.org/protobuf/types/known/timestamppb"
1515
)
1616

1717
type DataType string
@@ -100,52 +100,60 @@ func SupportsSemanticType(semanticType sql2.SemanticType) bool {
100100
}
101101

102102
func MapFieldType(fd *desc.FieldDescriptor) DataType {
103-
// Check for semantic type annotation first
104-
semanticType, _, hasSemanticType := proto.SemanticTypeInfo(fd)
105-
if hasSemanticType {
106-
if sqlType, supported := MapSemanticType(sql2.SemanticType(semanticType)); supported {
107-
return DataType(sqlType)
108-
}
109-
// Fall through to default mapping if semantic type not supported
110-
}
103+
// First, attempt semantic mapping to get the base type
104+
if semanticType, _, hasSemanticType := proto.SemanticTypeInfo(fd); hasSemanticType {
105+
if sqlType, supported := MapSemanticType(sql2.SemanticType(semanticType)); supported {
106+
base := DataType(sqlType)
107+
if fd.IsRepeated() {
108+
return DataType(fmt.Sprintf("%s[]", base))
109+
}
110+
return base
111+
}
112+
// Fall through to default mapping if semantic type not supported
113+
}
111114

112-
// Default protobuf type mapping
113-
t := fd.GetType()
114-
switch t {
115-
case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
116-
switch fd.GetMessageType().GetFullyQualifiedName() {
117-
case "google.protobuf.Timestamp":
118-
return TypeTimestamptz // Use timestamptz for protobuf timestamps
119-
default:
120-
panic(fmt.Sprintf("Message type not supported: %s", fd.GetMessageType().GetFullyQualifiedName()))
121-
}
122-
case descriptor.FieldDescriptorProto_TYPE_BOOL:
123-
return TypeBool
124-
case descriptor.FieldDescriptorProto_TYPE_INT32, descriptor.FieldDescriptorProto_TYPE_SINT32, descriptor.FieldDescriptorProto_TYPE_SFIXED32:
125-
return TypeInteger
126-
case descriptor.FieldDescriptorProto_TYPE_INT64, descriptor.FieldDescriptorProto_TYPE_SINT64, descriptor.FieldDescriptorProto_TYPE_SFIXED64:
127-
return TypeBigInt
128-
case descriptor.FieldDescriptorProto_TYPE_UINT64, descriptor.FieldDescriptorProto_TYPE_FIXED64:
129-
return TypeNumeric // Use NUMERIC for large unsigned integers
130-
case descriptor.FieldDescriptorProto_TYPE_UINT32, descriptor.FieldDescriptorProto_TYPE_FIXED32:
131-
return TypeBigInt // Use BIGINT for 32-bit unsigned (to avoid overflow)
132-
case descriptor.FieldDescriptorProto_TYPE_FLOAT:
133-
return TypeReal // Use REAL for single precision
134-
case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
135-
return TypeDouble
136-
case descriptor.FieldDescriptorProto_TYPE_STRING:
137-
return TypeVarchar
138-
case descriptor.FieldDescriptorProto_TYPE_BYTES:
139-
return TypeBytea // Use BYTEA for binary data
140-
case descriptor.FieldDescriptorProto_TYPE_ENUM:
141-
return TypeVarchar // Store enums as varchar
142-
default:
143-
panic(fmt.Sprintf("unsupported type: %s", t))
144-
}
115+
// Default protobuf type mapping to determine the base type
116+
var baseType DataType
117+
switch fd.GetType() {
118+
case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
119+
switch fd.GetMessageType().GetFullyQualifiedName() {
120+
case "google.protobuf.Timestamp":
121+
baseType = TypeTimestamptz // Use timestamptz for protobuf timestamps
122+
default:
123+
panic(fmt.Sprintf("Message type not supported: %s", fd.GetMessageType().GetFullyQualifiedName()))
124+
}
125+
case descriptor.FieldDescriptorProto_TYPE_BOOL:
126+
baseType = TypeBool
127+
case descriptor.FieldDescriptorProto_TYPE_INT32, descriptor.FieldDescriptorProto_TYPE_SINT32, descriptor.FieldDescriptorProto_TYPE_SFIXED32:
128+
baseType = TypeInteger
129+
case descriptor.FieldDescriptorProto_TYPE_INT64, descriptor.FieldDescriptorProto_TYPE_SINT64, descriptor.FieldDescriptorProto_TYPE_SFIXED64:
130+
baseType = TypeBigInt
131+
case descriptor.FieldDescriptorProto_TYPE_UINT64, descriptor.FieldDescriptorProto_TYPE_FIXED64:
132+
baseType = TypeNumeric // Use NUMERIC for large unsigned integers
133+
case descriptor.FieldDescriptorProto_TYPE_UINT32, descriptor.FieldDescriptorProto_TYPE_FIXED32:
134+
baseType = TypeBigInt // Use BIGINT for 32-bit unsigned (to avoid overflow)
135+
case descriptor.FieldDescriptorProto_TYPE_FLOAT:
136+
baseType = TypeReal // Use REAL for single precision
137+
case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
138+
baseType = TypeDouble
139+
case descriptor.FieldDescriptorProto_TYPE_STRING:
140+
baseType = TypeVarchar
141+
case descriptor.FieldDescriptorProto_TYPE_BYTES:
142+
baseType = TypeBytea // Use BYTEA for binary data
143+
case descriptor.FieldDescriptorProto_TYPE_ENUM:
144+
baseType = TypeVarchar // Store enums as varchar
145+
default:
146+
panic(fmt.Sprintf("unsupported type: %s", fd.GetType()))
147+
}
148+
149+
if fd.IsRepeated() {
150+
return DataType(fmt.Sprintf("%s[]", baseType))
151+
}
152+
return baseType
145153
}
146154

147155
func ValueToString(value any) (s string) {
148-
switch v := value.(type) {
156+
switch v := value.(type) {
149157
case string:
150158
s = "'" + strings.ReplaceAll(strings.ReplaceAll(v, "'", "''"), "\\", "\\\\") + "'"
151159
case int64:
@@ -173,13 +181,20 @@ func ValueToString(value any) (s string) {
173181
case time.Time:
174182
// Use RFC3339 format for timestamps
175183
s = "'" + v.Format(time.RFC3339) + "'"
176-
case *timestamppb.Timestamp:
177-
// Convert protobuf timestamp to timestamptz format
178-
s = "'" + v.AsTime().Format(time.RFC3339) + "'"
179-
default:
180-
panic(fmt.Sprintf("unsupported type: %T", v))
181-
}
182-
return
184+
case *timestamppb.Timestamp:
185+
// Convert protobuf timestamp to timestamptz format
186+
s = "'" + v.AsTime().Format(time.RFC3339) + "'"
187+
// Handle array types for RisingWave (PostgreSQL-compatible array literals)
188+
case []interface{}:
189+
var elements []string
190+
for _, elem := range v {
191+
elements = append(elements, ValueToString(elem))
192+
}
193+
s = "{" + strings.Join(elements, ",") + "}"
194+
default:
195+
panic(fmt.Sprintf("unsupported type: %T", v))
196+
}
197+
return
183198
}
184199

185200
// ConvertSemanticValue converts a value according to semantic type and format hint for RisingWave

0 commit comments

Comments
 (0)