Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const (
specialTypeIface
specialTypeTag
specialTypeTime
specialTypeJSONUnmarshalerIface
)

type typeInfo struct {
Expand Down Expand Up @@ -75,6 +76,8 @@ func newTypeInfo(t reflect.Type) *typeInfo {
tInfo.spclType = specialTypeUnexportedUnmarshalerIface
} else if reflect.PointerTo(t).Implements(typeUnmarshaler) {
tInfo.spclType = specialTypeUnmarshalerIface
} else if reflect.PointerTo(t).Implements(typeJSONUnmarshaler) {
tInfo.spclType = specialTypeJSONUnmarshalerIface
}

switch k {
Expand Down
9 changes: 9 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cbor

import (
"fmt"
"io"
"strconv"
)

Expand Down Expand Up @@ -180,3 +181,11 @@ func validBuiltinTag(tagNum uint64, contentHead byte) error {

return nil
}

// Transcoder is a scheme for transcoding a single CBOR encoded data item to or from a different
// data format.
type Transcoder interface {
// Transcode reads the data item in its source format from a Reader and writes a
// corresponding representation in its destination format to a Writer.
Transcode(dst io.Writer, src io.Reader) error
}
210 changes: 128 additions & 82 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package cbor

import (
"bytes"
"encoding"
"encoding/base64"
"encoding/binary"
Expand Down Expand Up @@ -906,6 +907,11 @@ type DecOptions struct {
// TextUnmarshaler specifies how to decode into types that implement
// encoding.TextUnmarshaler.
TextUnmarshaler TextUnmarshalerMode

// JSONUnmarshalerTranscoder sets the transcoding scheme used to unmarshal types that
// implement json.Unmarshaler but do not also implement cbor.Unmarshaler. If nil, decoding
// behavior is not influenced by whether or not a type implements json.Unmarshaler.
JSONUnmarshalerTranscoder Transcoder
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -1123,33 +1129,34 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore
}

dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
maxNestedLevels: opts.MaxNestedLevels,
maxArrayElements: opts.MaxArrayElements,
maxMapPairs: opts.MaxMapPairs,
indefLength: opts.IndefLength,
tagsMd: opts.TagsMd,
intDec: opts.IntDec,
mapKeyByteString: opts.MapKeyByteString,
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
fieldNameMatching: opts.FieldNameMatching,
bigIntDec: opts.BigIntDec,
defaultByteStringType: opts.DefaultByteStringType,
byteStringToString: opts.ByteStringToString,
fieldNameByteString: opts.FieldNameByteString,
unrecognizedTagToAny: opts.UnrecognizedTagToAny,
timeTagToAny: opts.TimeTagToAny,
simpleValues: simpleValues,
nanDec: opts.NaN,
infDec: opts.Inf,
byteStringToTime: opts.ByteStringToTime,
byteStringExpectedFormat: opts.ByteStringExpectedFormat,
bignumTag: opts.BignumTag,
binaryUnmarshaler: opts.BinaryUnmarshaler,
textUnmarshaler: opts.TextUnmarshaler,
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
maxNestedLevels: opts.MaxNestedLevels,
maxArrayElements: opts.MaxArrayElements,
maxMapPairs: opts.MaxMapPairs,
indefLength: opts.IndefLength,
tagsMd: opts.TagsMd,
intDec: opts.IntDec,
mapKeyByteString: opts.MapKeyByteString,
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
fieldNameMatching: opts.FieldNameMatching,
bigIntDec: opts.BigIntDec,
defaultByteStringType: opts.DefaultByteStringType,
byteStringToString: opts.ByteStringToString,
fieldNameByteString: opts.FieldNameByteString,
unrecognizedTagToAny: opts.UnrecognizedTagToAny,
timeTagToAny: opts.TimeTagToAny,
simpleValues: simpleValues,
nanDec: opts.NaN,
infDec: opts.Inf,
byteStringToTime: opts.ByteStringToTime,
byteStringExpectedFormat: opts.ByteStringExpectedFormat,
bignumTag: opts.BignumTag,
binaryUnmarshaler: opts.BinaryUnmarshaler,
textUnmarshaler: opts.TextUnmarshaler,
jsonUnmarshalerTranscoder: opts.JSONUnmarshalerTranscoder,
}

return &dm, nil
Expand Down Expand Up @@ -1202,34 +1209,35 @@ type DecMode interface {
}

type decMode struct {
tags tagProvider
dupMapKey DupMapKeyMode
timeTag DecTagMode
maxNestedLevels int
maxArrayElements int
maxMapPairs int
indefLength IndefLengthMode
tagsMd TagsMode
intDec IntDecMode
mapKeyByteString MapKeyByteStringMode
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
fieldNameMatching FieldNameMatchingMode
bigIntDec BigIntDecMode
defaultByteStringType reflect.Type
byteStringToString ByteStringToStringMode
fieldNameByteString FieldNameByteStringMode
unrecognizedTagToAny UnrecognizedTagToAnyMode
timeTagToAny TimeTagToAnyMode
simpleValues *SimpleValueRegistry
nanDec NaNMode
infDec InfMode
byteStringToTime ByteStringToTimeMode
byteStringExpectedFormat ByteStringExpectedFormatMode
bignumTag BignumTagMode
binaryUnmarshaler BinaryUnmarshalerMode
textUnmarshaler TextUnmarshalerMode
tags tagProvider
dupMapKey DupMapKeyMode
timeTag DecTagMode
maxNestedLevels int
maxArrayElements int
maxMapPairs int
indefLength IndefLengthMode
tagsMd TagsMode
intDec IntDecMode
mapKeyByteString MapKeyByteStringMode
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
fieldNameMatching FieldNameMatchingMode
bigIntDec BigIntDecMode
defaultByteStringType reflect.Type
byteStringToString ByteStringToStringMode
fieldNameByteString FieldNameByteStringMode
unrecognizedTagToAny UnrecognizedTagToAnyMode
timeTagToAny TimeTagToAnyMode
simpleValues *SimpleValueRegistry
nanDec NaNMode
infDec InfMode
byteStringToTime ByteStringToTimeMode
byteStringExpectedFormat ByteStringExpectedFormatMode
bignumTag BignumTagMode
binaryUnmarshaler BinaryUnmarshalerMode
textUnmarshaler TextUnmarshalerMode
jsonUnmarshalerTranscoder Transcoder
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand All @@ -1244,33 +1252,34 @@ func (dm *decMode) DecOptions() DecOptions {
}

return DecOptions{
DupMapKey: dm.dupMapKey,
TimeTag: dm.timeTag,
MaxNestedLevels: dm.maxNestedLevels,
MaxArrayElements: dm.maxArrayElements,
MaxMapPairs: dm.maxMapPairs,
IndefLength: dm.indefLength,
TagsMd: dm.tagsMd,
IntDec: dm.intDec,
MapKeyByteString: dm.mapKeyByteString,
ExtraReturnErrors: dm.extraReturnErrors,
DefaultMapType: dm.defaultMapType,
UTF8: dm.utf8,
FieldNameMatching: dm.fieldNameMatching,
BigIntDec: dm.bigIntDec,
DefaultByteStringType: dm.defaultByteStringType,
ByteStringToString: dm.byteStringToString,
FieldNameByteString: dm.fieldNameByteString,
UnrecognizedTagToAny: dm.unrecognizedTagToAny,
TimeTagToAny: dm.timeTagToAny,
SimpleValues: simpleValues,
NaN: dm.nanDec,
Inf: dm.infDec,
ByteStringToTime: dm.byteStringToTime,
ByteStringExpectedFormat: dm.byteStringExpectedFormat,
BignumTag: dm.bignumTag,
BinaryUnmarshaler: dm.binaryUnmarshaler,
TextUnmarshaler: dm.textUnmarshaler,
DupMapKey: dm.dupMapKey,
TimeTag: dm.timeTag,
MaxNestedLevels: dm.maxNestedLevels,
MaxArrayElements: dm.maxArrayElements,
MaxMapPairs: dm.maxMapPairs,
IndefLength: dm.indefLength,
TagsMd: dm.tagsMd,
IntDec: dm.intDec,
MapKeyByteString: dm.mapKeyByteString,
ExtraReturnErrors: dm.extraReturnErrors,
DefaultMapType: dm.defaultMapType,
UTF8: dm.utf8,
FieldNameMatching: dm.fieldNameMatching,
BigIntDec: dm.bigIntDec,
DefaultByteStringType: dm.defaultByteStringType,
ByteStringToString: dm.byteStringToString,
FieldNameByteString: dm.fieldNameByteString,
UnrecognizedTagToAny: dm.unrecognizedTagToAny,
TimeTagToAny: dm.timeTagToAny,
SimpleValues: simpleValues,
NaN: dm.nanDec,
Inf: dm.infDec,
ByteStringToTime: dm.byteStringToTime,
ByteStringExpectedFormat: dm.byteStringExpectedFormat,
BignumTag: dm.bignumTag,
BinaryUnmarshaler: dm.binaryUnmarshaler,
TextUnmarshaler: dm.textUnmarshaler,
JSONUnmarshalerTranscoder: dm.jsonUnmarshalerTranscoder,
}
}

Expand Down Expand Up @@ -1497,6 +1506,14 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin

case specialTypeUnexportedUnmarshalerIface:
return d.parseToUnexportedUnmarshaler(v)

case specialTypeJSONUnmarshalerIface:
// This special type implies that the type does not also implement
// cbor.Umarshaler.
if d.dm.jsonUnmarshalerTranscoder == nil {
break
}
return d.parseToJSONUnmarshaler(v)
}
}

Expand Down Expand Up @@ -1862,6 +1879,32 @@ func (d *decoder) parseToUnexportedUnmarshaler(v reflect.Value) error {
return errors.New("cbor: failed to assert " + v.Type().String() + " as cbor.unmarshaler")
}

// parseToJSONUnmarshaler parses CBOR data to be transcoded to JSON and passed to the value's
// implementation of the json.Unmarshaler interface. It assumes data is well-formed, and does not
// perform bounds checking.
func (d *decoder) parseToJSONUnmarshaler(v reflect.Value) error {
if d.nextCBORNil() && v.Kind() == reflect.Pointer && v.IsNil() {
d.skip()
return nil
}

if v.Kind() != reflect.Pointer && v.CanAddr() {
v = v.Addr()
}
if u, ok := v.Interface().(jsonUnmarshaler); ok {
start := d.off
d.skip()
e := getEncodeBuffer()
defer putEncodeBuffer(e)
if err := d.dm.jsonUnmarshalerTranscoder.Transcode(e, bytes.NewReader(d.data[start:d.off])); err != nil {
return &TranscodeError{err: err, rtype: v.Type(), sourceFormat: "cbor", targetFormat: "json"}
}
return u.UnmarshalJSON(e.Bytes())
}
d.skip()
return errors.New("cbor: failed to assert " + v.Type().String() + " as json.Unmarshaler")
}

// parse parses CBOR data and returns value in default Go type.
// It assumes data is well-formed, and does not perform bounds checking.
func (d *decoder) parse(skipSelfDescribedTag bool) (any, error) { //nolint:gocyclo
Expand Down Expand Up @@ -3018,6 +3061,8 @@ func (d *decoder) nextCBORNil() bool {
return d.data[d.off] == 0xf6 || d.data[d.off] == 0xf7
}

type jsonUnmarshaler interface{ UnmarshalJSON([]byte) error }

var (
typeIntf = reflect.TypeOf([]any(nil)).Elem()
typeTime = reflect.TypeOf(time.Time{})
Expand All @@ -3026,6 +3071,7 @@ var (
typeUnexportedUnmarshaler = reflect.TypeOf((*unmarshaler)(nil)).Elem()
typeBinaryUnmarshaler = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
typeTextUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
typeJSONUnmarshaler = reflect.TypeOf((*jsonUnmarshaler)(nil)).Elem()
typeString = reflect.TypeOf("")
typeByteSlice = reflect.TypeOf([]byte(nil))
)
Expand Down
Loading
Loading