Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 80 additions & 4 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,25 @@ func (tum TextUnmarshalerMode) valid() bool {
return tum >= 0 && tum < maxTextUnmarshalerMode
}

// FloatToIntMode specifies whether CBOR floating-point values can be decoded into Go integer types.
type FloatToIntMode int

const (
// FloatToIntForbidden disallows decoding CBOR floats into Go integer types.
FloatToIntForbidden FloatToIntMode = iota

// FloatToIntAllowExact permits decoding CBOR float values into Go integer
// types if the float value can be represented exactly in the destination
// type without loss of precision. NaN and infinity are never permitted.
FloatToIntAllowExact

maxFloatToIntMode
)

func (ftim FloatToIntMode) valid() bool {
return ftim >= 0 && ftim < maxFloatToIntMode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -912,6 +931,10 @@ type DecOptions struct {
// implement json.Unmarshaler but do not also implement cbor.Unmarshaler. If nil, decoding
// behavior is not influenced by whether or not a type implements json.Unmarshaler.
JSONUnmarshalerTranscoder Transcoder

// FloatToInt specifies whether CBOR floating-point values can be decoded into Go integer
// types. By default, decoding a CBOR float into a Go integer type produces an error.
FloatToInt FloatToIntMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -1128,6 +1151,10 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore
return nil, errors.New("cbor: invalid TextUnmarshaler " + strconv.Itoa(int(opts.TextUnmarshaler)))
}

if !opts.FloatToInt.valid() {
return nil, errors.New("cbor: invalid FloatToInt " + strconv.Itoa(int(opts.FloatToInt)))
}

dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand Down Expand Up @@ -1157,6 +1184,7 @@ func (opts DecOptions) decMode() (*decMode, error) { //nolint:gocritic // ignore
binaryUnmarshaler: opts.BinaryUnmarshaler,
textUnmarshaler: opts.TextUnmarshaler,
jsonUnmarshalerTranscoder: opts.JSONUnmarshalerTranscoder,
floatToInt: opts.FloatToInt,
}

return &dm, nil
Expand Down Expand Up @@ -1238,6 +1266,7 @@ type decMode struct {
binaryUnmarshaler BinaryUnmarshalerMode
textUnmarshaler TextUnmarshalerMode
jsonUnmarshalerTranscoder Transcoder
floatToInt FloatToIntMode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand Down Expand Up @@ -1280,6 +1309,7 @@ func (dm *decMode) DecOptions() DecOptions {
BinaryUnmarshaler: dm.binaryUnmarshaler,
TextUnmarshaler: dm.textUnmarshaler,
JSONUnmarshalerTranscoder: dm.jsonUnmarshalerTranscoder,
FloatToInt: dm.floatToInt,
}
}

Expand Down Expand Up @@ -1584,15 +1614,15 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
switch ai {
case additionalInformationAsFloat16:
f := float64(float16.Frombits(uint16(val)).Float32()) //nolint:gosec
return fillFloat(t, f, v)
return fillFloat(t, f, v, d.dm.floatToInt)

case additionalInformationAsFloat32:
f := float64(math.Float32frombits(uint32(val))) //nolint:gosec
return fillFloat(t, f, v)
return fillFloat(t, f, v, d.dm.floatToInt)

case additionalInformationAsFloat64:
f := math.Float64frombits(val)
return fillFloat(t, f, v)
return fillFloat(t, f, v, d.dm.floatToInt)

default: // ai <= 24
if d.dm.simpleValues.rejected[SimpleValue(val)] { //nolint:gosec
Expand Down Expand Up @@ -3144,7 +3174,7 @@ func fillBool(t cborType, val bool, v reflect.Value) error {
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillFloat(t cborType, val float64, v reflect.Value) error {
func fillFloat(t cborType, val float64, v reflect.Value, fti FloatToIntMode) error {
switch v.Kind() {
case reflect.Float32, reflect.Float64:
if v.OverflowFloat(val) {
Expand All @@ -3157,6 +3187,52 @@ func fillFloat(t cborType, val float64, v reflect.Value) error {
v.SetFloat(val)
return nil
}

if fti != FloatToIntAllowExact {
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

// Modf returns (NaN, NaN) for NaN and (+/-Inf, NaN) for +/-Inf, so
// frac != 0 is true in all cases.
i, frac := math.Modf(val)
if frac != 0 {
return &UnmarshalTypeError{
CBORType: t.String(),
GoType: v.Type().String(),
errorMsg: strconv.FormatFloat(val, 'f', -1, 64) + " is not an integral value",
}
}

// Range-check before converting to int64/uint64, because the Go spec
// makes float-to-integer conversion implementation-dependent when the
// value is out of range. MinInt64 (-2^63), 2^63, and 2^64 are all
// exact as float64 because they are powers of two.
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n := int64(i)
if i < math.MinInt64 || i >= 1<<63 || v.OverflowInt(n) {
return &UnmarshalTypeError{
CBORType: t.String(),
GoType: v.Type().String(),
errorMsg: strconv.FormatFloat(val, 'f', -1, 64) + " overflows " + v.Type().String(),
}
}
v.SetInt(n)
return nil

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
n := uint64(i)
if i < 0 || i >= 1<<64 || v.OverflowUint(n) {
return &UnmarshalTypeError{
CBORType: t.String(),
GoType: v.Type().String(),
errorMsg: strconv.FormatFloat(val, 'f', -1, 64) + " overflows " + v.Type().String(),
}
}
v.SetUint(n)
return nil
}

return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

Expand Down
Loading
Loading