From d8e0aa63e6022b050e9b256636950137f5a721bb Mon Sep 17 00:00:00 2001 From: Ben Luddy Date: Tue, 10 Jun 2025 14:02:27 -0400 Subject: [PATCH 1/2] Add optional encoding support for json.Marshaler. If the user provides a JSON-to-CBOR transcode function, a value whose type implements json.Marshaler and not cbor.Marshaler will be encoded by first calling its MarshalJSON method, then transcoding the result to CBOR. Signed-off-by: Ben Luddy --- common.go | 9 ++ encode.go | 125 +++++++++++++++--- encode_test.go | 253 +++++++++++++++++++++++++++++++++--- example_transcoding_test.go | 63 +++++++++ 4 files changed, 416 insertions(+), 34 deletions(-) create mode 100644 example_transcoding_test.go diff --git a/common.go b/common.go index ec038a49..9cf33cd2 100644 --- a/common.go +++ b/common.go @@ -5,6 +5,7 @@ package cbor import ( "fmt" + "io" "strconv" ) @@ -180,3 +181,11 @@ func validBuiltinTag(tagNum uint64, contentHead byte) error { return nil } + +// Transcoder is a scheme for transcoding a single CBOR encoded data item to or from a different +// data format. +type Transcoder interface { + // Transcode reads the data item in its source format from a Reader and writes a + // corresponding representation in its destination format to a Writer. + Transcode(dst io.Writer, src io.Reader) error +} diff --git a/encode.go b/encode.go index bf223147..c550617c 100644 --- a/encode.go +++ b/encode.go @@ -132,6 +132,20 @@ func (e *MarshalerError) Unwrap() error { return e.err } +type TranscodeError struct { + err error + rtype reflect.Type + sourceFormat, targetFormat string +} + +func (e TranscodeError) Error() string { + return "cbor: cannot transcode from " + e.sourceFormat + " to " + e.targetFormat + ": " + e.err.Error() +} + +func (e TranscodeError) Unwrap() error { + return e.err +} + // UnsupportedTypeError is returned by Marshal when attempting to encode value // of an unsupported type. 
type UnsupportedTypeError struct { @@ -588,6 +602,11 @@ type EncOptions struct { // TextMarshaler specifies how to encode types that implement encoding.TextMarshaler. TextMarshaler TextMarshalerMode + + // JSONMarshalerTranscoder sets the transcoding scheme used to marshal types that implement + // json.Marshaler but do not also implement cbor.Marshaler. If nil, encoding behavior is not + // influenced by whether or not a type implements json.Marshaler. + JSONMarshalerTranscoder Transcoder } // CanonicalEncOptions returns EncOptions for "Canonical CBOR" encoding, @@ -821,6 +840,7 @@ func (opts EncOptions) encMode() (*encMode, error) { //nolint:gocritic // ignore byteArray: opts.ByteArray, binaryMarshaler: opts.BinaryMarshaler, textMarshaler: opts.TextMarshaler, + jsonMarshalerTranscoder: opts.JSONMarshalerTranscoder, } return &em, nil } @@ -867,6 +887,7 @@ type encMode struct { byteArray ByteArrayMode binaryMarshaler BinaryMarshalerMode textMarshaler TextMarshalerMode + jsonMarshalerTranscoder Transcoder } var defaultEncMode, _ = EncOptions{}.encMode() @@ -943,23 +964,24 @@ func getMarshalerDecMode(indefLength IndefLengthMode, tagsMd TagsMode) *decMode // EncOptions returns user specified options used to create this EncMode. 
func (em *encMode) EncOptions() EncOptions { return EncOptions{ - Sort: em.sort, - ShortestFloat: em.shortestFloat, - NaNConvert: em.nanConvert, - InfConvert: em.infConvert, - BigIntConvert: em.bigIntConvert, - Time: em.time, - TimeTag: em.timeTag, - IndefLength: em.indefLength, - NilContainers: em.nilContainers, - TagsMd: em.tagsMd, - OmitEmpty: em.omitEmpty, - String: em.stringType, - FieldName: em.fieldName, - ByteSliceLaterFormat: em.byteSliceLaterFormat, - ByteArray: em.byteArray, - BinaryMarshaler: em.binaryMarshaler, - TextMarshaler: em.textMarshaler, + Sort: em.sort, + ShortestFloat: em.shortestFloat, + NaNConvert: em.nanConvert, + InfConvert: em.infConvert, + BigIntConvert: em.bigIntConvert, + Time: em.time, + TimeTag: em.timeTag, + IndefLength: em.indefLength, + NilContainers: em.nilContainers, + TagsMd: em.tagsMd, + OmitEmpty: em.omitEmpty, + String: em.stringType, + FieldName: em.fieldName, + ByteSliceLaterFormat: em.byteSliceLaterFormat, + ByteArray: em.byteArray, + BinaryMarshaler: em.binaryMarshaler, + TextMarshaler: em.textMarshaler, + JSONMarshalerTranscoder: em.jsonMarshalerTranscoder, } } @@ -1779,6 +1801,59 @@ func (tme textMarshalerEncoder) isEmpty(em *encMode, v reflect.Value) (bool, err return len(data) == 0, nil } +type jsonMarshalerEncoder struct { + alternateEncode encodeFunc + alternateIsEmpty isEmptyFunc +} + +func (jme jsonMarshalerEncoder) encode(e *bytes.Buffer, em *encMode, v reflect.Value) error { + if em.jsonMarshalerTranscoder == nil { + return jme.alternateEncode(e, em, v) + } + + vt := v.Type() + m, ok := v.Interface().(jsonMarshaler) + if !ok { + pv := reflect.New(vt) + pv.Elem().Set(v) + m = pv.Interface().(jsonMarshaler) + } + + json, err := m.MarshalJSON() + if err != nil { + return err + } + + offset := e.Len() + + if b := em.encTagBytes(vt); b != nil { + e.Write(b) + } + + if err := em.jsonMarshalerTranscoder.Transcode(e, bytes.NewReader(json)); err != nil { + return &TranscodeError{err: err, rtype: vt, sourceFormat: 
"json", targetFormat: "cbor"} + } + + // Validate that the transcode function has written exactly one well-formed data item. + d := decoder{data: e.Bytes()[offset:], dm: getMarshalerDecMode(em.indefLength, em.tagsMd)} + if err := d.wellformed(false, true); err != nil { + e.Truncate(offset) + return &TranscodeError{err: err, rtype: vt, sourceFormat: "json", targetFormat: "cbor"} + } + + return nil +} + +func (jme jsonMarshalerEncoder) isEmpty(em *encMode, v reflect.Value) (bool, error) { + if em.jsonMarshalerTranscoder == nil { + return jme.alternateIsEmpty(em, v) + } + + // As with types implementing cbor.Marshaler, transcoded json.Marshaler values always encode + // as exactly one complete CBOR data item. + return false, nil +} + func encodeMarshalerType(e *bytes.Buffer, em *encMode, v reflect.Value) error { if em.tagsMd == TagsForbidden && v.Type() == typeRawTag { return errors.New("cbor: cannot encode cbor.RawTag when TagsMd is TagsForbidden") @@ -1882,10 +1957,13 @@ func encodeHead(e *bytes.Buffer, t byte, n uint64) int { return headSize } +type jsonMarshaler interface{ MarshalJSON() ([]byte, error) } + var ( typeMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem() typeBinaryMarshaler = reflect.TypeOf((*encoding.BinaryMarshaler)(nil)).Elem() typeTextMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + typeJSONMarshaler = reflect.TypeOf((*jsonMarshaler)(nil)).Elem() typeRawMessage = reflect.TypeOf(RawMessage(nil)) typeByteString = reflect.TypeOf(ByteString("")) ) @@ -1939,6 +2017,19 @@ func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc, izf ief = tme.isEmpty }() } + if reflect.PointerTo(t).Implements(typeJSONMarshaler) { + defer func() { + // capture encoding method used for modes that don't support transcoding + // from types that implement json.Marshaler. 
+ jme := jsonMarshalerEncoder{ + alternateEncode: ef, + alternateIsEmpty: ief, + } + ef = jme.encode + ief = jme.isEmpty + }() + } + switch k { case reflect.Bool: return encodeBool, isEmptyBool, getIsZeroFunc(t) diff --git a/encode_test.go b/encode_test.go index 860a7db5..e9dbfdb1 100644 --- a/encode_test.go +++ b/encode_test.go @@ -4503,25 +4503,32 @@ func TestEncOptionsTagsForbidden(t *testing.T) { } } +type stubTranscoder struct{} + +func (stubTranscoder) Transcode(io.Writer, io.Reader) error { + return nil +} + func TestEncOptions(t *testing.T) { opts1 := EncOptions{ - Sort: SortBytewiseLexical, - ShortestFloat: ShortestFloat16, - NaNConvert: NaNConvertPreserveSignal, - InfConvert: InfConvertNone, - BigIntConvert: BigIntConvertNone, - Time: TimeRFC3339Nano, - TimeTag: EncTagRequired, - IndefLength: IndefLengthForbidden, - NilContainers: NilContainerAsEmpty, - TagsMd: TagsAllowed, - OmitEmpty: OmitEmptyGoValue, - String: StringToByteString, - FieldName: FieldNameToByteString, - ByteSliceLaterFormat: ByteSliceLaterFormatBase16, - ByteArray: ByteArrayToArray, - BinaryMarshaler: BinaryMarshalerNone, - TextMarshaler: TextMarshalerTextString, + Sort: SortBytewiseLexical, + ShortestFloat: ShortestFloat16, + NaNConvert: NaNConvertPreserveSignal, + InfConvert: InfConvertNone, + BigIntConvert: BigIntConvertNone, + Time: TimeRFC3339Nano, + TimeTag: EncTagRequired, + IndefLength: IndefLengthForbidden, + NilContainers: NilContainerAsEmpty, + TagsMd: TagsAllowed, + OmitEmpty: OmitEmptyGoValue, + String: StringToByteString, + FieldName: FieldNameToByteString, + ByteSliceLaterFormat: ByteSliceLaterFormatBase16, + ByteArray: ByteArrayToArray, + BinaryMarshaler: BinaryMarshalerNone, + TextMarshaler: TextMarshalerTextString, + JSONMarshalerTranscoder: stubTranscoder{}, } ov := reflect.ValueOf(opts1) for i := 0; i < ov.NumField(); i++ { @@ -5754,3 +5761,215 @@ func TestTextMarshalerModeError(t *testing.T) { }) } } + +type stubJSONMarshaler struct { + JSON string + Error error +} + 
+func (m stubJSONMarshaler) MarshalJSON() ([]byte, error) { + return []byte(m.JSON), m.Error +} + +type stubJSONMarshalerPointerReceiver struct { + JSON string +} + +func (m *stubJSONMarshalerPointerReceiver) MarshalJSON() ([]byte, error) { + return []byte(m.JSON), nil +} + +type transcodeFunc func(io.Writer, io.Reader) error + +func (f transcodeFunc) Transcode(w io.Writer, r io.Reader) error { + return f(w, r) +} + +func TestJSONMarshalerTranscoderNil(t *testing.T) { + enc, err := EncOptions{}.EncMode() + if err != nil { + t.Fatal(err) + } + + { + // default encode behavior of underlying type + got, err := enc.Marshal(&stubJSONMarshalerPointerReceiver{JSON: "z"}) + if err != nil { + t.Fatal(err) + } + + want := []byte{0xa1, 0x64, 'J', 'S', 'O', 'N', 0x61, 'z'} + if !bytes.Equal(got, want) { + t.Errorf("want 0x%x, got 0x%x", want, got) + } + } + + { + // default empty condition of underlying type + got, err := enc.Marshal(struct { + M stubJSONMarshalerPointerReceiver `cbor:"m,omitempty"` + }{}) + if err != nil { + t.Fatal(err) + } + + want := []byte{0xa1, 0x61, 'm', 0xa1, 0x64, 'J', 'S', 'O', 'N', 0x60} + if !bytes.Equal(got, want) { + t.Errorf("want 0x%x, got 0x%x", want, got) + } + } + +} + +func TestJSONMarshalerTranscoder(t *testing.T) { + testTags := NewTagSet() + if err := testTags.Add(TagOptions{EncTag: EncTagRequired}, reflect.TypeOf(stubJSONMarshaler{}), 9999); err != nil { + t.Fatal(err) + } + + for _, tc := range []struct { + name string + value any + tags TagSet + + transcodeInput []byte + transcodeOutput []byte + transcodeError error + + wantCborData []byte + wantErrorMsg string + }{ + { + name: "value-receiver marshaler", + value: stubJSONMarshaler{JSON: `"a"`}, + transcodeInput: []byte(`"a"`), + transcodeOutput: []byte{0x61, 'a'}, + wantCborData: []byte{0x61, 'a'}, + }, + { + name: "transcoder returns non-nil error", + value: stubJSONMarshaler{JSON: `"a"`}, + transcodeInput: []byte(`"a"`), + transcodeError: errors.New("test"), + wantErrorMsg: 
TranscodeError{ + err: errors.New("test"), + rtype: reflect.TypeOf(stubJSONMarshaler{}), + sourceFormat: "json", + targetFormat: "cbor", + }.Error(), + }, + { + name: "transcoder produces invalid cbor", + value: stubJSONMarshaler{JSON: `"a"`}, + transcodeInput: []byte(`"a"`), + transcodeOutput: []byte{0xff}, + wantErrorMsg: TranscodeError{ + err: errors.New(`cbor: unexpected "break" code`), + rtype: reflect.TypeOf(stubJSONMarshaler{}), + sourceFormat: "json", + targetFormat: "cbor", + }.Error(), + }, + { + name: "transcoder produces short cbor", + value: stubJSONMarshaler{JSON: `"a"`}, + transcodeInput: []byte(`"a"`), + transcodeOutput: []byte{0x61}, + wantErrorMsg: TranscodeError{ + err: io.ErrUnexpectedEOF, + rtype: reflect.TypeOf(stubJSONMarshaler{}), + sourceFormat: "json", + targetFormat: "cbor", + }.Error(), + }, + { + name: "transcoder produces extraneous cbor", + value: stubJSONMarshaler{JSON: `"a"`}, + transcodeInput: []byte(`"a"`), + transcodeOutput: []byte{0x61, 'a', 0x61, 'b'}, + wantErrorMsg: TranscodeError{ + err: &ExtraneousDataError{numOfBytes: 2, index: 2}, + rtype: reflect.TypeOf(stubJSONMarshaler{}), + sourceFormat: "json", + targetFormat: "cbor", + }.Error(), + }, + { + name: "marshaler returns non-nil error", + value: stubJSONMarshaler{Error: errors.New("test")}, + wantErrorMsg: "test", + }, + { + name: "value-receiver marshaler with registered tag", + tags: testTags, + value: stubJSONMarshaler{JSON: `"a"`}, + transcodeInput: []byte(`"a"`), + transcodeOutput: []byte{0x61, 'a'}, // "a" + wantCborData: []byte{0xd9, 0x27, 0x0f, 0x61, 'a'}, // 9999("a") + }, + { + name: "pointer-receiver marshaler", + value: stubJSONMarshalerPointerReceiver{JSON: `"a"`}, + transcodeInput: []byte(`"a"`), + transcodeOutput: []byte{0x61, 'a'}, + wantCborData: []byte{0x61, 'a'}, + }, + { + name: "never omitempty", + value: struct { + M stubJSONMarshaler `cbor:"m,omitempty"` + }{M: stubJSONMarshaler{JSON: `"a"`}}, + transcodeInput: []byte(`"a"`), + transcodeOutput: 
[]byte{0x61, 'a'}, + wantCborData: []byte{0xa1, 0x61, 'm', 0x61, 'a'}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + opts := EncOptions{ + JSONMarshalerTranscoder: transcodeFunc(func(w io.Writer, r io.Reader) error { + source, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if got := string(source); got != string(tc.transcodeInput) { + t.Errorf("transcoder got input %q, want %q", got, string(tc.transcodeInput)) + } + + if tc.transcodeError != nil { + return tc.transcodeError + } + + _, err = w.Write(tc.transcodeOutput) + return err + }), + } + var ( + enc EncMode + err error + ) + if tc.tags != nil { + enc, err = opts.EncModeWithTags(tc.tags) + } else { + enc, err = opts.EncMode() + } + if err != nil { + t.Fatal(err) + } + + b, err := enc.Marshal(tc.value) + if tc.wantErrorMsg != "" { + if err == nil { + t.Errorf("Marshal(%v) didn't return an error, want error %q", tc.value, tc.wantErrorMsg) + } else if gotErrorMsg := err.Error(); gotErrorMsg != tc.wantErrorMsg { + t.Errorf("Marshal(%v) returned error %q, want %q", tc.value, gotErrorMsg, tc.wantErrorMsg) + } + } else { + if err != nil { + t.Errorf("Marshal(%v) returned non-nil error %v", tc.value, err) + } else if !bytes.Equal(b, tc.wantCborData) { + t.Errorf("Marshal(%v) = 0x%x, want 0x%x", tc.value, b, tc.wantCborData) + } + } + }) + } +} diff --git a/example_transcoding_test.go b/example_transcoding_test.go new file mode 100644 index 00000000..a3551338 --- /dev/null +++ b/example_transcoding_test.go @@ -0,0 +1,63 @@ +// Copyright (c) Faye Amacker. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. 
+ +package cbor_test + +import ( + "encoding/json" + "fmt" + "io" + + "github.com/fxamacker/cbor/v2" +) + +type TranscoderFunc func(io.Writer, io.Reader) error + +func (f TranscoderFunc) Transcode(w io.Writer, r io.Reader) error { + return f(w, r) +} + +func ExampleTranscoder_fromJSON() { + enc, _ := cbor.EncOptions{ + JSONMarshalerTranscoder: TranscoderFunc(func(w io.Writer, r io.Reader) error { + d := json.NewDecoder(r) + + for { + token, err := d.Token() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + switch token { + case json.Delim('['): + if _, err := w.Write([]byte{0x9f}); err != nil { + return err + } + case json.Delim('{'): + if _, err := w.Write([]byte{0xbf}); err != nil { + return err + } + case json.Delim(']'), json.Delim('}'): + if _, err := w.Write([]byte{0xff}); err != nil { + return err + } + default: + b, err := cbor.Marshal(token) + if err != nil { + return err + } + if _, err := w.Write(b); err != nil { + return err + } + } + } + }), + }.EncMode() + + got, _ := enc.Marshal(json.RawMessage(`{"a": [true, "z", {"y": 3.14}], "b": {"c": null}}`)) + diag, _ := cbor.Diagnose(got) + fmt.Println(diag) + // Output: {_ "a": [_ true, "z", {_ "y": 3.14}], "b": {_ "c": null}} +} From ba129ebd429db2a25caa198ffd9b92116763b3c8 Mon Sep 17 00:00:00 2001 From: Ben Luddy Date: Tue, 10 Jun 2025 15:42:32 -0400 Subject: [PATCH 2/2] Add optional support for json.Unmarshaler. Users can provide a function to transcode an encoded CBOR data item to JSON. If provided, then unmarshaling into a value whose type implements json.Unmarshaler, but not cbor.Unmarshaler, will first transcode the input bytes to JSON and then invoke UnmarshalJSON on the destination value. 
Signed-off-by: Ben Luddy --- cache.go | 3 + decode.go | 210 ++++++++++++++++++++++-------------- decode_test.go | 137 ++++++++++++++++++----- example_transcoding_test.go | 22 ++++ 4 files changed, 263 insertions(+), 109 deletions(-) diff --git a/cache.go b/cache.go index 0d96b988..5051f110 100644 --- a/cache.go +++ b/cache.go @@ -37,6 +37,7 @@ const ( specialTypeIface specialTypeTag specialTypeTime + specialTypeJSONUnmarshalerIface ) type typeInfo struct { @@ -75,6 +76,8 @@ func newTypeInfo(t reflect.Type) *typeInfo { tInfo.spclType = specialTypeUnexportedUnmarshalerIface } else if reflect.PointerTo(t).Implements(typeUnmarshaler) { tInfo.spclType = specialTypeUnmarshalerIface + } else if reflect.PointerTo(t).Implements(typeJSONUnmarshaler) { + tInfo.spclType = specialTypeJSONUnmarshalerIface } switch k { diff --git a/decode.go b/decode.go index 170e1253..ba12f2f3 100644 --- a/decode.go +++ b/decode.go @@ -4,6 +4,7 @@ package cbor import ( + "bytes" "encoding" "encoding/base64" "encoding/binary" @@ -906,6 +907,11 @@ type DecOptions struct { // TextUnmarshaler specifies how to decode into types that implement // encoding.TextUnmarshaler. TextUnmarshaler TextUnmarshalerMode + + // JSONUnmarshalerTranscoder sets the transcoding scheme used to unmarshal types that + // implement json.Unmarshaler but do not also implement cbor.Unmarshaler. If nil, decoding + // behavior is not influenced by whether or not a type implements json.Unmarshaler. + JSONUnmarshalerTranscoder Transcoder } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). 
@@ -1123,33 +1129,34 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore } dm := decMode{ - dupMapKey: opts.DupMapKey, - timeTag: opts.TimeTag, - maxNestedLevels: opts.MaxNestedLevels, - maxArrayElements: opts.MaxArrayElements, - maxMapPairs: opts.MaxMapPairs, - indefLength: opts.IndefLength, - tagsMd: opts.TagsMd, - intDec: opts.IntDec, - mapKeyByteString: opts.MapKeyByteString, - extraReturnErrors: opts.ExtraReturnErrors, - defaultMapType: opts.DefaultMapType, - utf8: opts.UTF8, - fieldNameMatching: opts.FieldNameMatching, - bigIntDec: opts.BigIntDec, - defaultByteStringType: opts.DefaultByteStringType, - byteStringToString: opts.ByteStringToString, - fieldNameByteString: opts.FieldNameByteString, - unrecognizedTagToAny: opts.UnrecognizedTagToAny, - timeTagToAny: opts.TimeTagToAny, - simpleValues: simpleValues, - nanDec: opts.NaN, - infDec: opts.Inf, - byteStringToTime: opts.ByteStringToTime, - byteStringExpectedFormat: opts.ByteStringExpectedFormat, - bignumTag: opts.BignumTag, - binaryUnmarshaler: opts.BinaryUnmarshaler, - textUnmarshaler: opts.TextUnmarshaler, + dupMapKey: opts.DupMapKey, + timeTag: opts.TimeTag, + maxNestedLevels: opts.MaxNestedLevels, + maxArrayElements: opts.MaxArrayElements, + maxMapPairs: opts.MaxMapPairs, + indefLength: opts.IndefLength, + tagsMd: opts.TagsMd, + intDec: opts.IntDec, + mapKeyByteString: opts.MapKeyByteString, + extraReturnErrors: opts.ExtraReturnErrors, + defaultMapType: opts.DefaultMapType, + utf8: opts.UTF8, + fieldNameMatching: opts.FieldNameMatching, + bigIntDec: opts.BigIntDec, + defaultByteStringType: opts.DefaultByteStringType, + byteStringToString: opts.ByteStringToString, + fieldNameByteString: opts.FieldNameByteString, + unrecognizedTagToAny: opts.UnrecognizedTagToAny, + timeTagToAny: opts.TimeTagToAny, + simpleValues: simpleValues, + nanDec: opts.NaN, + infDec: opts.Inf, + byteStringToTime: opts.ByteStringToTime, + byteStringExpectedFormat: opts.ByteStringExpectedFormat, + 
bignumTag: opts.BignumTag, + binaryUnmarshaler: opts.BinaryUnmarshaler, + textUnmarshaler: opts.TextUnmarshaler, + jsonUnmarshalerTranscoder: opts.JSONUnmarshalerTranscoder, } return &dm, nil @@ -1202,34 +1209,35 @@ type DecMode interface { } type decMode struct { - tags tagProvider - dupMapKey DupMapKeyMode - timeTag DecTagMode - maxNestedLevels int - maxArrayElements int - maxMapPairs int - indefLength IndefLengthMode - tagsMd TagsMode - intDec IntDecMode - mapKeyByteString MapKeyByteStringMode - extraReturnErrors ExtraDecErrorCond - defaultMapType reflect.Type - utf8 UTF8Mode - fieldNameMatching FieldNameMatchingMode - bigIntDec BigIntDecMode - defaultByteStringType reflect.Type - byteStringToString ByteStringToStringMode - fieldNameByteString FieldNameByteStringMode - unrecognizedTagToAny UnrecognizedTagToAnyMode - timeTagToAny TimeTagToAnyMode - simpleValues *SimpleValueRegistry - nanDec NaNMode - infDec InfMode - byteStringToTime ByteStringToTimeMode - byteStringExpectedFormat ByteStringExpectedFormatMode - bignumTag BignumTagMode - binaryUnmarshaler BinaryUnmarshalerMode - textUnmarshaler TextUnmarshalerMode + tags tagProvider + dupMapKey DupMapKeyMode + timeTag DecTagMode + maxNestedLevels int + maxArrayElements int + maxMapPairs int + indefLength IndefLengthMode + tagsMd TagsMode + intDec IntDecMode + mapKeyByteString MapKeyByteStringMode + extraReturnErrors ExtraDecErrorCond + defaultMapType reflect.Type + utf8 UTF8Mode + fieldNameMatching FieldNameMatchingMode + bigIntDec BigIntDecMode + defaultByteStringType reflect.Type + byteStringToString ByteStringToStringMode + fieldNameByteString FieldNameByteStringMode + unrecognizedTagToAny UnrecognizedTagToAnyMode + timeTagToAny TimeTagToAnyMode + simpleValues *SimpleValueRegistry + nanDec NaNMode + infDec InfMode + byteStringToTime ByteStringToTimeMode + byteStringExpectedFormat ByteStringExpectedFormatMode + bignumTag BignumTagMode + binaryUnmarshaler BinaryUnmarshalerMode + textUnmarshaler 
TextUnmarshalerMode + jsonUnmarshalerTranscoder Transcoder } var defaultDecMode, _ = DecOptions{}.decMode() @@ -1244,33 +1252,34 @@ func (dm *decMode) DecOptions() DecOptions { } return DecOptions{ - DupMapKey: dm.dupMapKey, - TimeTag: dm.timeTag, - MaxNestedLevels: dm.maxNestedLevels, - MaxArrayElements: dm.maxArrayElements, - MaxMapPairs: dm.maxMapPairs, - IndefLength: dm.indefLength, - TagsMd: dm.tagsMd, - IntDec: dm.intDec, - MapKeyByteString: dm.mapKeyByteString, - ExtraReturnErrors: dm.extraReturnErrors, - DefaultMapType: dm.defaultMapType, - UTF8: dm.utf8, - FieldNameMatching: dm.fieldNameMatching, - BigIntDec: dm.bigIntDec, - DefaultByteStringType: dm.defaultByteStringType, - ByteStringToString: dm.byteStringToString, - FieldNameByteString: dm.fieldNameByteString, - UnrecognizedTagToAny: dm.unrecognizedTagToAny, - TimeTagToAny: dm.timeTagToAny, - SimpleValues: simpleValues, - NaN: dm.nanDec, - Inf: dm.infDec, - ByteStringToTime: dm.byteStringToTime, - ByteStringExpectedFormat: dm.byteStringExpectedFormat, - BignumTag: dm.bignumTag, - BinaryUnmarshaler: dm.binaryUnmarshaler, - TextUnmarshaler: dm.textUnmarshaler, + DupMapKey: dm.dupMapKey, + TimeTag: dm.timeTag, + MaxNestedLevels: dm.maxNestedLevels, + MaxArrayElements: dm.maxArrayElements, + MaxMapPairs: dm.maxMapPairs, + IndefLength: dm.indefLength, + TagsMd: dm.tagsMd, + IntDec: dm.intDec, + MapKeyByteString: dm.mapKeyByteString, + ExtraReturnErrors: dm.extraReturnErrors, + DefaultMapType: dm.defaultMapType, + UTF8: dm.utf8, + FieldNameMatching: dm.fieldNameMatching, + BigIntDec: dm.bigIntDec, + DefaultByteStringType: dm.defaultByteStringType, + ByteStringToString: dm.byteStringToString, + FieldNameByteString: dm.fieldNameByteString, + UnrecognizedTagToAny: dm.unrecognizedTagToAny, + TimeTagToAny: dm.timeTagToAny, + SimpleValues: simpleValues, + NaN: dm.nanDec, + Inf: dm.infDec, + ByteStringToTime: dm.byteStringToTime, + ByteStringExpectedFormat: dm.byteStringExpectedFormat, + BignumTag: dm.bignumTag, + 
BinaryUnmarshaler: dm.binaryUnmarshaler, + TextUnmarshaler: dm.textUnmarshaler, + JSONUnmarshalerTranscoder: dm.jsonUnmarshalerTranscoder, } } @@ -1497,6 +1506,14 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin case specialTypeUnexportedUnmarshalerIface: return d.parseToUnexportedUnmarshaler(v) + + case specialTypeJSONUnmarshalerIface: + // This special type implies that the type does not also implement + // cbor.Unmarshaler. + if d.dm.jsonUnmarshalerTranscoder == nil { + break + } + return d.parseToJSONUnmarshaler(v) } } @@ -1862,6 +1879,32 @@ func (d *decoder) parseToUnexportedUnmarshaler(v reflect.Value) error { return errors.New("cbor: failed to assert " + v.Type().String() + " as cbor.unmarshaler") } +// parseToJSONUnmarshaler parses CBOR data to be transcoded to JSON and passed to the value's +// implementation of the json.Unmarshaler interface. It assumes data is well-formed, and does not +// perform bounds checking. +func (d *decoder) parseToJSONUnmarshaler(v reflect.Value) error { + if d.nextCBORNil() && v.Kind() == reflect.Pointer && v.IsNil() { + d.skip() + return nil + } + + if v.Kind() != reflect.Pointer && v.CanAddr() { + v = v.Addr() + } + if u, ok := v.Interface().(jsonUnmarshaler); ok { + start := d.off + d.skip() + e := getEncodeBuffer() + defer putEncodeBuffer(e) + if err := d.dm.jsonUnmarshalerTranscoder.Transcode(e, bytes.NewReader(d.data[start:d.off])); err != nil { + return &TranscodeError{err: err, rtype: v.Type(), sourceFormat: "cbor", targetFormat: "json"} + } + return u.UnmarshalJSON(e.Bytes()) + } + d.skip() + return errors.New("cbor: failed to assert " + v.Type().String() + " as json.Unmarshaler") +} + // parse parses CBOR data and returns value in default Go type. // It assumes data is well-formed, and does not perform bounds checking. 
func (d *decoder) parse(skipSelfDescribedTag bool) (any, error) { //nolint:gocyclo @@ -3018,6 +3061,8 @@ func (d *decoder) nextCBORNil() bool { return d.data[d.off] == 0xf6 || d.data[d.off] == 0xf7 } +type jsonUnmarshaler interface{ UnmarshalJSON([]byte) error } + var ( typeIntf = reflect.TypeOf([]any(nil)).Elem() typeTime = reflect.TypeOf(time.Time{}) @@ -3026,6 +3071,7 @@ var ( typeUnexportedUnmarshaler = reflect.TypeOf((*unmarshaler)(nil)).Elem() typeBinaryUnmarshaler = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem() typeTextUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + typeJSONUnmarshaler = reflect.TypeOf((*jsonUnmarshaler)(nil)).Elem() typeString = reflect.TypeOf("") typeByteSlice = reflect.TypeOf([]byte(nil)) ) diff --git a/decode_test.go b/decode_test.go index 9df55d56..72bccc82 100644 --- a/decode_test.go +++ b/decode_test.go @@ -7,6 +7,7 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "encoding/json" "errors" "fmt" "io" @@ -4899,33 +4900,34 @@ func TestDecOptions(t *testing.T) { } opts1 := DecOptions{ - DupMapKey: DupMapKeyEnforcedAPF, - TimeTag: DecTagRequired, - MaxNestedLevels: 100, - MaxArrayElements: 102, - MaxMapPairs: 101, - IndefLength: IndefLengthForbidden, - TagsMd: TagsForbidden, - IntDec: IntDecConvertSigned, - MapKeyByteString: MapKeyByteStringForbidden, - ExtraReturnErrors: ExtraDecErrorUnknownField, - DefaultMapType: reflect.TypeOf(map[string]any(nil)), - UTF8: UTF8DecodeInvalid, - FieldNameMatching: FieldNameMatchingCaseSensitive, - BigIntDec: BigIntDecodePointer, - DefaultByteStringType: reflect.TypeOf(""), - ByteStringToString: ByteStringToStringAllowed, - FieldNameByteString: FieldNameByteStringAllowed, - UnrecognizedTagToAny: UnrecognizedTagContentToAny, - TimeTagToAny: TimeTagToRFC3339, - SimpleValues: simpleValues, - NaN: NaNDecodeForbidden, - Inf: InfDecodeForbidden, - ByteStringToTime: ByteStringToTimeAllowed, - ByteStringExpectedFormat: ByteStringExpectedBase64URL, - BignumTag: 
BignumTagForbidden, - BinaryUnmarshaler: BinaryUnmarshalerNone, - TextUnmarshaler: TextUnmarshalerTextString, + DupMapKey: DupMapKeyEnforcedAPF, + TimeTag: DecTagRequired, + MaxNestedLevels: 100, + MaxArrayElements: 102, + MaxMapPairs: 101, + IndefLength: IndefLengthForbidden, + TagsMd: TagsForbidden, + IntDec: IntDecConvertSigned, + MapKeyByteString: MapKeyByteStringForbidden, + ExtraReturnErrors: ExtraDecErrorUnknownField, + DefaultMapType: reflect.TypeOf(map[string]any(nil)), + UTF8: UTF8DecodeInvalid, + FieldNameMatching: FieldNameMatchingCaseSensitive, + BigIntDec: BigIntDecodePointer, + DefaultByteStringType: reflect.TypeOf(""), + ByteStringToString: ByteStringToStringAllowed, + FieldNameByteString: FieldNameByteStringAllowed, + UnrecognizedTagToAny: UnrecognizedTagContentToAny, + TimeTagToAny: TimeTagToRFC3339, + SimpleValues: simpleValues, + NaN: NaNDecodeForbidden, + Inf: InfDecodeForbidden, + ByteStringToTime: ByteStringToTimeAllowed, + ByteStringExpectedFormat: ByteStringExpectedBase64URL, + BignumTag: BignumTagForbidden, + BinaryUnmarshaler: BinaryUnmarshalerNone, + TextUnmarshaler: TextUnmarshalerTextString, + JSONUnmarshalerTranscoder: stubTranscoder{}, } ov := reflect.ValueOf(opts1) for i := 0; i < ov.NumField(); i++ { @@ -10235,3 +10237,84 @@ func TestTextUnmarshalerModeError(t *testing.T) { t.Errorf("want: %q, got: %q", want, got) } } + +func TestJSONUnmarshalerTranscoder(t *testing.T) { + for _, tc := range []struct { + name string + in []byte + + transcodeInput []byte + transcodeOutput []byte + transcodeError error + + want any + wantErrorMsg string + }{ + { + name: "successful transcode", + in: []byte{0xf5}, + + transcodeInput: []byte{0xf5}, + transcodeOutput: []byte("true"), + + want: json.RawMessage("true"), + }, + { + name: "transcode returns non-nil error", + in: []byte{0xf5}, + + transcodeInput: []byte{0xf5}, + transcodeError: errors.New("test"), + + want: json.RawMessage("true"), + wantErrorMsg: TranscodeError{ + err: errors.New("test"), + 
rtype: reflect.TypeOf((*json.RawMessage)(nil)), + sourceFormat: "cbor", + targetFormat: "json", + }.Error(), + }, + } { + t.Run(tc.name, func(t *testing.T) { + dec, err := DecOptions{ + JSONUnmarshalerTranscoder: transcodeFunc(func(w io.Writer, r io.Reader) error { + source, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if got := string(source); got != string(tc.transcodeInput) { + t.Errorf("transcoder got input %q, want %q", got, string(tc.transcodeInput)) + } + + if tc.transcodeError != nil { + return tc.transcodeError + } + + _, err = w.Write(tc.transcodeOutput) + return err + + }), + }.DecMode() + if err != nil { + t.Fatal(err) + } + + gotrv := reflect.New(reflect.TypeOf(tc.want)) + err = dec.Unmarshal(tc.in, gotrv.Interface()) + if tc.wantErrorMsg != "" { + if err == nil { + t.Errorf("Unmarshal(0x%x) didn't return an error, want error %q", tc.in, tc.wantErrorMsg) + } else if gotErrorMsg := err.Error(); gotErrorMsg != tc.wantErrorMsg { + t.Errorf("Unmarshal(0x%x) returned error %q, want %q", tc.in, gotErrorMsg, tc.wantErrorMsg) + } + } else { + if err != nil { + t.Errorf("Unmarshal(0x%x) returned non-nil error %v", tc.in, err) + } else if got := gotrv.Elem().Interface(); !reflect.DeepEqual(tc.want, got) { + t.Errorf("Unmarshal(0x%x): %v, want %v", tc.in, got, tc.want) + } + } + + }) + } +} diff --git a/example_transcoding_test.go b/example_transcoding_test.go index a3551338..7491ce2c 100644 --- a/example_transcoding_test.go +++ b/example_transcoding_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "reflect" "github.com/fxamacker/cbor/v2" ) @@ -61,3 +62,24 @@ func ExampleTranscoder_fromJSON() { fmt.Println(diag) // Output: {_ "a": [_ true, "z", {_ "y": 3.14}], "b": {_ "c": null}} } + +func ExampleTranscoder_toJSON() { + var dec cbor.DecMode + dec, _ = cbor.DecOptions{ + DefaultMapType: reflect.TypeOf(map[string]any{}), + JSONUnmarshalerTranscoder: TranscoderFunc(func(w io.Writer, r io.Reader) error { + var tmp any + if err := 
dec.NewDecoder(r).Decode(&tmp); err != nil { + return err + } + return json.NewEncoder(w).Encode(tmp) + }), + }.DecMode() + + var got json.RawMessage + if err := dec.Unmarshal(cbor.RawMessage{0xa2, 0x61, 'a', 0x01, 0x61, 'b', 0x83, 0xf4, 0xf5, 0xf6}, &got); err != nil { + panic(err) + } + fmt.Println(string(got)) + // Output: {"a":1,"b":[false,true,null]} +}