diff --git a/common/types/map_test.go b/common/types/map_test.go index b16422c75..d0e8bf32c 100644 --- a/common/types/map_test.go +++ b/common/types/map_test.go @@ -892,7 +892,7 @@ func TestProtoMapConvertToNative(t *testing.T) { if mapVal4.Equal(mapVal) != True || mapVal.Equal(mapVal4) != True { t.Errorf("mapVal4.Equal(mapVal) returned false, wanted true") } - convMap, err = mapVal.ConvertToNative(reflect.TypeOf(&pb.Map{})) + convMap, err = mapVal.ConvertToNative(reflect.TypeOf(map[string]string{})) if err != nil { t.Fatalf("mapVal.ConvertToNative() failed: %v", err) } diff --git a/common/types/pb/type.go b/common/types/pb/type.go index 171494f07..74c005792 100644 --- a/common/types/pb/type.go +++ b/common/types/pb/type.go @@ -20,6 +20,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" dynamicpb "google.golang.org/protobuf/types/dynamicpb" @@ -254,30 +255,108 @@ func (fd *FieldDescription) GetFrom(target any) (any, error) { } pbRef := v.ProtoReflect() pbDesc := pbRef.Descriptor() - var fieldVal any + var fieldVal protoreflect.Value if pbDesc == fd.desc.ContainingMessage() { // When the target protobuf shares the same message descriptor instance as the field // descriptor, use the cached field descriptor value. - fieldVal = pbRef.Get(fd.desc).Interface() + fieldVal = pbRef.Get(fd.desc) } else { // Otherwise, fallback to a dynamic lookup of the field descriptor from the target // instance as an attempt to use the cached field descriptor will result in a panic. - fieldVal = pbRef.Get(pbDesc.Fields().ByName(protoreflect.Name(fd.Name()))).Interface() + fieldVal = pbRef.Get(pbDesc.Fields().ByName(protoreflect.Name(fd.Name()))) } - switch fv := fieldVal.(type) { + return fd.getNativeValue(fieldVal) +} + +func (fd *FieldDescription) getNativeType(v protoreflect.Value) (reflect.Type, error) { + switch fv := v.Interface().(type) { + case protoreflect.Message: + // Make sure to unwrap well-known protobuf types before returning. + unwrapped, _, err := fd.MaybeUnwrapDynamic(fv) + return reflect.TypeOf(unwrapped), err + case protoreflect.EnumNumber: + enumType, err := protoregistry.GlobalTypes.FindEnumByName(fd.desc.Enum().FullName()) + if err != nil { + return nil, err + } + return reflect.TypeOf(enumType.New(0)), nil + case protoreflect.List: + if fv == nil { + return nil, nil + } + + element := fv.NewElement() + et, err := fd.getNativeType(element) + if err != nil { + return nil, err + } + return reflect.SliceOf(et), nil + case protoreflect.Map: + vt, err := fd.getNativeType(fv.NewValue()) + if err != nil { + return nil, err + } + return reflect.MapOf(fd.KeyType.reflectType, vt), nil + default: + return reflect.TypeOf(fv), nil + } +} + +func (fd *FieldDescription) getNativeValue(v protoreflect.Value) (any, error) { + switch fv := v.Interface().(type) { // Fast-path return for primitive types. - case bool, []byte, float32, float64, int32, int64, string, uint32, uint64, protoreflect.List: + case bool, []byte, float32, float64, int32, int64, string, uint32, uint64: return fv, nil case protoreflect.EnumNumber: - return int64(fv), nil + enumType, err := protoregistry.GlobalTypes.FindEnumByName(fd.desc.Enum().FullName()) + if err != nil { + return nil, err + } + return enumType.New(fv), nil case protoreflect.Map: - // Return a wrapper around the protobuf-reflected Map types which carries additional - // information about the key and value definitions of the map. - return &Map{Map: fv, KeyType: fd.KeyType, ValueType: fd.ValueType}, nil + if fv == nil { + return nil, nil + } + + mapType, err := fd.getNativeType(v) + if err != nil { + return nil, err + } + + m := reflect.MakeMap(mapType) + fv.Range(func(mk protoreflect.MapKey, v protoreflect.Value) bool { + vv, err := fd.getNativeValue(v) + if err != nil { + return false + } + m.SetMapIndex(reflect.ValueOf(mk.Interface()), reflect.ValueOf(vv)) + return true + }) + return m.Interface(), nil case protoreflect.Message: // Make sure to unwrap well-known protobuf types before returning. unwrapped, _, err := fd.MaybeUnwrapDynamic(fv) return unwrapped, err + case protoreflect.List: + if fv == nil { + return nil, nil + } + + sliceType, err := fd.getNativeType(v) + if err != nil { + return nil, err + } + + slice := reflect.MakeSlice(sliceType, fv.Len(), fv.Len()) + + for i := 0; i < fv.Len(); i++ { + elementVal, err := fd.getNativeValue(fv.Get(i)) + if err != nil { + return nil, err + } + slice.Index(i).Set(reflect.ValueOf(elementVal)) + } + return slice.Interface(), nil default: return fv, nil } diff --git a/common/types/pb/type_test.go b/common/types/pb/type_test.go index bf2ad58d6..6162b7442 100644 --- a/common/types/pb/type_test.go +++ b/common/types/pb/type_test.go @@ -162,6 +162,31 @@ func TestFieldDescriptionGetFrom(t *testing.T) { SingleStruct: jsonStruct(t, map[string]any{ "null": nil, }), + RepeatedInt32: []int32{1, 2, 3}, + RepeatedInt64: []int64{1, 2, 3}, + RepeatedUint32: []uint32{1, 2, 3}, + RepeatedUint64: []uint64{1, 2, 3}, + RepeatedSint32: []int32{1, 2, 3}, + RepeatedSint64: []int64{1, 2, 3}, + RepeatedFixed32: []uint32{1, 2, 3}, + RepeatedFixed64: []uint64{1, 2, 3}, + RepeatedSfixed32: []int32{1, 2, 3}, + RepeatedSfixed64: []int64{1, 2, 3}, + RepeatedFloat: []float32{1.0, 2.0, 3.0}, + RepeatedDouble: []float64{1.0, 2.0, 3.0}, + RepeatedBool: []bool{true, false, true}, + RepeatedString: []string{"a", "b", "c"}, + RepeatedBytes: [][]byte{{1, 2, 3}, {4, 5, 6}}, + RepeatedNestedMessage: []*proto3pb.TestAllTypes_NestedMessage{{Bb: 123}, {Bb: 456}}, + RepeatedNestedEnum: []proto3pb.TestAllTypes_NestedEnum{proto3pb.TestAllTypes_BAR, proto3pb.TestAllTypes_BAZ}, + RepeatedStringPiece: []string{"a", "b", "c"}, + RepeatedCord: []string{"a", "b", "c"}, + RepeatedLazyMessage: []*proto3pb.TestAllTypes_NestedMessage{{Bb: 123}, {Bb: 456}}, + MapStringString: map[string]string{"a": "1", "b": "2"}, + MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{ + 1: {Payload: &proto3pb.TestAllTypes{SingleUint64: 12}}, + }, + ImportedEnums: []proto3pb.ImportedGlobalEnum{proto3pb.ImportedGlobalEnum_IMPORT_FOO}, } msgName := string(msg.ProtoReflect().Descriptor().FullName()) _, err := pbdb.RegisterMessage(msg) @@ -182,11 +207,36 @@ func TestFieldDescriptionGetFrom(t *testing.T) { "single_nested_message": &proto3pb.TestAllTypes_NestedMessage{ Bb: 123, }, - "standalone_enum": int64(1), + "standalone_enum": proto3pb.TestAllTypes_BAR, "single_value": "hello world", "single_struct": jsonStruct(t, map[string]any{ "null": nil, }), + "repeated_int32": []int32{1, 2, 3}, + "repeated_int64": []int64{1, 2, 3}, + "repeated_uint32": []uint32{1, 2, 3}, + "repeated_uint64": []uint64{1, 2, 3}, + "repeated_sint32": []int32{1, 2, 3}, + "repeated_sint64": []int64{1, 2, 3}, + "repeated_fixed32": []uint32{1, 2, 3}, + "repeated_fixed64": []uint64{1, 2, 3}, + "repeated_sfixed32": []int32{1, 2, 3}, + "repeated_sfixed64": []int64{1, 2, 3}, + "repeated_float": []float32{1.0, 2.0, 3.0}, + "repeated_double": []float64{1.0, 2.0, 3.0}, + "repeated_bool": []bool{true, false, true}, + "repeated_string": []string{"a", "b", "c"}, + "repeated_bytes": [][]byte{{1, 2, 3}, {4, 5, 6}}, + "repeated_nested_message": []*proto3pb.TestAllTypes_NestedMessage{{Bb: 123}, {Bb: 456}}, + "repeated_nested_enum": []proto3pb.TestAllTypes_NestedEnum{proto3pb.TestAllTypes_BAR, proto3pb.TestAllTypes_BAZ}, + "repeated_string_piece": []string{"a", "b", "c"}, + "repeated_cord": []string{"a", "b", "c"}, + "repeated_lazy_message": []*proto3pb.TestAllTypes_NestedMessage{{Bb: 123}, {Bb: 456}}, + "map_string_string": map[string]string{"a": "1", "b": "2"}, + "map_int64_nested_type": map[int64]*proto3pb.NestedTestAllTypes{ + 1: {Payload: &proto3pb.TestAllTypes{SingleUint64: 12}}, + }, + "imported_enums": []proto3pb.ImportedGlobalEnum{proto3pb.ImportedGlobalEnum_IMPORT_FOO}, } for field, want := range expected { f, found := td.FieldByName(field) @@ -200,11 +250,23 @@ func TestFieldDescriptionGetFrom(t *testing.T) { switch g := got.(type) { case proto.Message: if !proto.Equal(g, want.(proto.Message)) { - t.Errorf("got field %s value %v, wanted %v", field, g, want) + t.Errorf("got field %s type %T, value %v, wanted type %T, value %v", field, g, g, want, want) + } + case []*proto3pb.TestAllTypes_NestedMessage: + for i, gv := range g { + if !proto.Equal(gv, want.([]*proto3pb.TestAllTypes_NestedMessage)[i]) { + t.Errorf("got field %s[%d] type %T, value %v, wanted type %T, value %v", field, i, gv, gv, want.([]*proto3pb.TestAllTypes_NestedMessage)[i], want.([]*proto3pb.TestAllTypes_NestedMessage)[i]) + } + } + case map[int64]*proto3pb.NestedTestAllTypes: + for k, gv := range g { + if !proto.Equal(gv, want.(map[int64]*proto3pb.NestedTestAllTypes)[k]) { + t.Errorf("got field %s[%d] type %T, value %v, wanted type %T, value %v", field, k, gv, gv, want.(map[int64]*proto3pb.NestedTestAllTypes)[k], want.(map[int64]*proto3pb.NestedTestAllTypes)[k]) + } } default: if !reflect.DeepEqual(g, want) { - t.Errorf("got field %s value %v, wanted %v", field, g, want) + t.Errorf("got field %s type %T, value %v, wanted type %T, value %v", field, g, g, want, want) } } }