|
1 | 1 | package risingwave |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "encoding/hex" |
5 | | - "fmt" |
6 | | - "strconv" |
7 | | - "strings" |
8 | | - "time" |
| 4 | + "encoding/hex" |
| 5 | + "fmt" |
| 6 | + "strconv" |
| 7 | + "strings" |
| 8 | + "time" |
9 | 9 |
|
10 | | - "github.com/golang/protobuf/protoc-gen-go/descriptor" |
11 | | - "github.com/jhump/protoreflect/desc" |
12 | | - sql2 "github.com/streamingfast/substreams-sink-sql/db_proto/sql" |
13 | | - "github.com/streamingfast/substreams-sink-sql/proto" |
14 | | - "google.golang.org/protobuf/types/known/timestamppb" |
| 10 | + "github.com/golang/protobuf/protoc-gen-go/descriptor" |
| 11 | + "github.com/jhump/protoreflect/desc" |
| 12 | + sql2 "github.com/streamingfast/substreams-sink-sql/db_proto/sql" |
| 13 | + "github.com/streamingfast/substreams-sink-sql/proto" |
| 14 | + "google.golang.org/protobuf/types/known/timestamppb" |
15 | 15 | ) |
16 | 16 |
|
17 | 17 | type DataType string |
@@ -100,52 +100,60 @@ func SupportsSemanticType(semanticType sql2.SemanticType) bool { |
100 | 100 | } |
101 | 101 |
|
102 | 102 | func MapFieldType(fd *desc.FieldDescriptor) DataType { |
103 | | - // Check for semantic type annotation first |
104 | | - semanticType, _, hasSemanticType := proto.SemanticTypeInfo(fd) |
105 | | - if hasSemanticType { |
106 | | - if sqlType, supported := MapSemanticType(sql2.SemanticType(semanticType)); supported { |
107 | | - return DataType(sqlType) |
108 | | - } |
109 | | - // Fall through to default mapping if semantic type not supported |
110 | | - } |
| 103 | + // First, attempt semantic mapping to get the base type |
| 104 | + if semanticType, _, hasSemanticType := proto.SemanticTypeInfo(fd); hasSemanticType { |
| 105 | + if sqlType, supported := MapSemanticType(sql2.SemanticType(semanticType)); supported { |
| 106 | + base := DataType(sqlType) |
| 107 | + if fd.IsRepeated() { |
| 108 | + return DataType(fmt.Sprintf("%s[]", base)) |
| 109 | + } |
| 110 | + return base |
| 111 | + } |
| 112 | + // Fall through to default mapping if semantic type not supported |
| 113 | + } |
111 | 114 |
|
112 | | - // Default protobuf type mapping |
113 | | - t := fd.GetType() |
114 | | - switch t { |
115 | | - case descriptor.FieldDescriptorProto_TYPE_MESSAGE: |
116 | | - switch fd.GetMessageType().GetFullyQualifiedName() { |
117 | | - case "google.protobuf.Timestamp": |
118 | | - return TypeTimestamptz // Use timestamptz for protobuf timestamps |
119 | | - default: |
120 | | - panic(fmt.Sprintf("Message type not supported: %s", fd.GetMessageType().GetFullyQualifiedName())) |
121 | | - } |
122 | | - case descriptor.FieldDescriptorProto_TYPE_BOOL: |
123 | | - return TypeBool |
124 | | - case descriptor.FieldDescriptorProto_TYPE_INT32, descriptor.FieldDescriptorProto_TYPE_SINT32, descriptor.FieldDescriptorProto_TYPE_SFIXED32: |
125 | | - return TypeInteger |
126 | | - case descriptor.FieldDescriptorProto_TYPE_INT64, descriptor.FieldDescriptorProto_TYPE_SINT64, descriptor.FieldDescriptorProto_TYPE_SFIXED64: |
127 | | - return TypeBigInt |
128 | | - case descriptor.FieldDescriptorProto_TYPE_UINT64, descriptor.FieldDescriptorProto_TYPE_FIXED64: |
129 | | - return TypeNumeric // Use NUMERIC for large unsigned integers |
130 | | - case descriptor.FieldDescriptorProto_TYPE_UINT32, descriptor.FieldDescriptorProto_TYPE_FIXED32: |
131 | | - return TypeBigInt // Use BIGINT for 32-bit unsigned (to avoid overflow) |
132 | | - case descriptor.FieldDescriptorProto_TYPE_FLOAT: |
133 | | - return TypeReal // Use REAL for single precision |
134 | | - case descriptor.FieldDescriptorProto_TYPE_DOUBLE: |
135 | | - return TypeDouble |
136 | | - case descriptor.FieldDescriptorProto_TYPE_STRING: |
137 | | - return TypeVarchar |
138 | | - case descriptor.FieldDescriptorProto_TYPE_BYTES: |
139 | | - return TypeBytea // Use BYTEA for binary data |
140 | | - case descriptor.FieldDescriptorProto_TYPE_ENUM: |
141 | | - return TypeVarchar // Store enums as varchar |
142 | | - default: |
143 | | - panic(fmt.Sprintf("unsupported type: %s", t)) |
144 | | - } |
| 115 | + // Default protobuf type mapping to determine the base type |
| 116 | + var baseType DataType |
| 117 | + switch fd.GetType() { |
| 118 | + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: |
| 119 | + switch fd.GetMessageType().GetFullyQualifiedName() { |
| 120 | + case "google.protobuf.Timestamp": |
| 121 | + baseType = TypeTimestamptz // Use timestamptz for protobuf timestamps |
| 122 | + default: |
| 123 | + panic(fmt.Sprintf("Message type not supported: %s", fd.GetMessageType().GetFullyQualifiedName())) |
| 124 | + } |
| 125 | + case descriptor.FieldDescriptorProto_TYPE_BOOL: |
| 126 | + baseType = TypeBool |
| 127 | + case descriptor.FieldDescriptorProto_TYPE_INT32, descriptor.FieldDescriptorProto_TYPE_SINT32, descriptor.FieldDescriptorProto_TYPE_SFIXED32: |
| 128 | + baseType = TypeInteger |
| 129 | + case descriptor.FieldDescriptorProto_TYPE_INT64, descriptor.FieldDescriptorProto_TYPE_SINT64, descriptor.FieldDescriptorProto_TYPE_SFIXED64: |
| 130 | + baseType = TypeBigInt |
| 131 | + case descriptor.FieldDescriptorProto_TYPE_UINT64, descriptor.FieldDescriptorProto_TYPE_FIXED64: |
| 132 | + baseType = TypeNumeric // Use NUMERIC for large unsigned integers |
| 133 | + case descriptor.FieldDescriptorProto_TYPE_UINT32, descriptor.FieldDescriptorProto_TYPE_FIXED32: |
| 134 | + baseType = TypeBigInt // Use BIGINT for 32-bit unsigned (to avoid overflow) |
| 135 | + case descriptor.FieldDescriptorProto_TYPE_FLOAT: |
| 136 | + baseType = TypeReal // Use REAL for single precision |
| 137 | + case descriptor.FieldDescriptorProto_TYPE_DOUBLE: |
| 138 | + baseType = TypeDouble |
| 139 | + case descriptor.FieldDescriptorProto_TYPE_STRING: |
| 140 | + baseType = TypeVarchar |
| 141 | + case descriptor.FieldDescriptorProto_TYPE_BYTES: |
| 142 | + baseType = TypeBytea // Use BYTEA for binary data |
| 143 | + case descriptor.FieldDescriptorProto_TYPE_ENUM: |
| 144 | + baseType = TypeVarchar // Store enums as varchar |
| 145 | + default: |
| 146 | + panic(fmt.Sprintf("unsupported type: %s", fd.GetType())) |
| 147 | + } |
| 148 | + |
| 149 | + if fd.IsRepeated() { |
| 150 | + return DataType(fmt.Sprintf("%s[]", baseType)) |
| 151 | + } |
| 152 | + return baseType |
145 | 153 | } |
146 | 154 |
|
147 | 155 | func ValueToString(value any) (s string) { |
148 | | - switch v := value.(type) { |
| 156 | + switch v := value.(type) { |
149 | 157 | case string: |
150 | 158 | s = "'" + strings.ReplaceAll(strings.ReplaceAll(v, "'", "''"), "\\", "\\\\") + "'" |
151 | 159 | case int64: |
@@ -173,13 +181,20 @@ func ValueToString(value any) (s string) { |
173 | 181 | case time.Time: |
174 | 182 | // Use RFC3339 format for timestamps |
175 | 183 | s = "'" + v.Format(time.RFC3339) + "'" |
176 | | - case *timestamppb.Timestamp: |
177 | | - // Convert protobuf timestamp to timestamptz format |
178 | | - s = "'" + v.AsTime().Format(time.RFC3339) + "'" |
179 | | - default: |
180 | | - panic(fmt.Sprintf("unsupported type: %T", v)) |
181 | | - } |
182 | | - return |
| 184 | + case *timestamppb.Timestamp: |
| 185 | + // Convert protobuf timestamp to timestamptz format |
| 186 | + s = "'" + v.AsTime().Format(time.RFC3339) + "'" |
| 187 | + // Handle array types for RisingWave (PostgreSQL-compatible array literals) |
| 188 | + case []interface{}: |
| 189 | + var elements []string |
| 190 | + for _, elem := range v { |
| 191 | + elements = append(elements, ValueToString(elem)) |
| 192 | + } |
| 193 | + s = "{" + strings.Join(elements, ",") + "}" |
| 194 | + default: |
| 195 | + panic(fmt.Sprintf("unsupported type: %T", v)) |
| 196 | + } |
| 197 | + return |
183 | 198 | } |
184 | 199 |
|
185 | 200 | // ConvertSemanticValue converts a value according to semantic type and format hint for RisingWave |
|
0 commit comments