diff --git a/CHANGELOG.md b/CHANGELOG.md index aaeb6eb61..3c375fb83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,12 @@ Changelog All notable changes to this project will be documented in this file. +## Unreleased + +### Added + +- Schema: Added a `Decimal` common type carrying precision and scale via a new `LogicalParams` struct, enabling lossless conversion between Avro, Parquet, and database `NUMBER`/`NUMERIC` decimals. Includes `NewDecimal`, `FormatDecimal`/`ParseDecimal`, and `DecimalParams.Format`/`Parse`/`ValidateValue` helpers, plus a `Common.Validate` entry point. `ParseFromAny` and `InferFromAny` now accept `encoding/json.Number` values, so schemas pipelined through `json.Decoder.UseNumber()` round-trip without precision loss. (@Jeffail) + ## 4.70.0 - 2026-04-02 ### Added diff --git a/public/schema/common.go b/public/schema/common.go index 9349b76f0..e920755a4 100644 --- a/public/schema/common.go +++ b/public/schema/common.go @@ -47,11 +47,24 @@ // This optimization is particularly useful in scenarios where schemas are // transmitted over the network or stored in external systems, as it eliminates // the need to parse and recalculate fingerprints on cache hits. +// +// # Parameterised logical types +// +// Some types carry parameters beyond what the type identifier conveys. For +// example, a Decimal requires a precision and a scale. These parameters are +// attached to the [Common] schema via the [Common.Logical] field, which holds +// a [LogicalParams] struct. Only the field within [LogicalParams] that +// corresponds to the schema's [Common.Type] should be set. +// +// Use [Common.Validate] to confirm a schema's parameters are internally +// consistent before relying on it. package schema import ( "crypto/sha256" "encoding/hex" + "encoding/json" + "errors" "fmt" "io" ) @@ -75,6 +88,14 @@ const ( Union CommonType = 12 Timestamp CommonType = 13 Any CommonType = 14 + Decimal CommonType = 15 +) + +// Decimal precision bounds. 
The upper bound matches the widest precision that +// can be represented losslessly across Avro, Parquet and Oracle NUMBER. +const ( + DecimalMinPrecision int32 = 1 + DecimalMaxPrecision int32 = 38 ) // String returns a human readable string representation of the type. @@ -108,6 +129,8 @@ func (t CommonType) String() string { return "TIMESTAMP" case Any: return "ANY" + case Decimal: + return "DECIMAL" default: return "UNKNOWN" } @@ -143,6 +166,8 @@ func typeFromStr(v string) (CommonType, error) { return Timestamp, nil case "ANY": return Any, nil + case "DECIMAL": + return Decimal, nil default: return 0, fmt.Errorf("unrecognised type string: %v", v) } @@ -157,6 +182,31 @@ type Common struct { Type CommonType Optional bool Children []Common + + // Logical holds parameters for parameterised types (e.g. Decimal). Only + // the field within LogicalParams that corresponds to Type should be + // populated; setting parameters that do not apply to the type is a + // validation error. + Logical *LogicalParams +} + +// LogicalParams groups the optional parameter blocks that parameterised +// CommonType values may carry. Each parameterised type has its own field; +// at most one is expected to be non-nil for any given Common schema. +type LogicalParams struct { + Decimal *DecimalParams +} + +// DecimalParams describes a fixed-precision decimal number. +// +// Precision is the total number of significant digits and must be in +// [DecimalMinPrecision, DecimalMaxPrecision]. Scale is the number of digits +// to the right of the decimal point and must be in [0, Precision]. These +// constraints describe the lossless intersection across Avro, Parquet and +// Oracle NUMBER. 
+type DecimalParams struct { + Precision int32 + Scale int32 } const ( @@ -165,6 +215,8 @@ const ( anyFieldOptional = "optional" anyFieldChildren = "children" anyFieldFingerprint = "fingerprint" + anyFieldPrecision = "precision" + anyFieldScale = "scale" ) // ToAny serializes the common schema into a generic Go value, with structured @@ -203,11 +255,32 @@ func (c *Common) ToAny() any { m[anyFieldChildren] = children } + if c.Type == Decimal && c.Logical != nil && c.Logical.Decimal != nil { + m[anyFieldPrecision] = int64(c.Logical.Decimal.Precision) + m[anyFieldScale] = int64(c.Logical.Decimal.Scale) + } + return m } -// ParseFromAny deserializes a common schema from a generic Go value. +// ParseFromAny deserializes a common schema from a generic Go value. The +// resulting schema is validated via [Common.Validate] before being returned. func ParseFromAny(v any) (Common, error) { + c, err := parseFromAnyNoValidate(v) + if err != nil { + return c, err + } + if err := c.Validate(); err != nil { + return c, err + } + return c, nil +} + +// parseFromAnyNoValidate performs the structural deserialisation without +// running [Common.Validate]. It is used internally so that recursive parsing +// of nested children does not validate each subtree O(depth) times; the +// public [ParseFromAny] entry point validates exactly once at the top level. 
+func parseFromAnyNoValidate(v any) (Common, error) { var c Common obj, ok := v.(map[string]any) @@ -236,13 +309,13 @@ func ParseFromAny(v any) (Common, error) { if optionalB, ok := optional.(bool); ok { c.Optional = optionalB } else { - return c, fmt.Errorf("expected field `optional` of type string, got %T", obj[anyFieldOptional]) + return c, fmt.Errorf("expected field `optional` of type bool, got %T", obj[anyFieldOptional]) } } if cArr, ok := obj[anyFieldChildren].([]any); ok { for i, cEle := range cArr { - cChild, err := ParseFromAny(cEle) + cChild, err := parseFromAnyNoValidate(cEle) if err != nil { return c, fmt.Errorf("child element %v: %w", i, err) } @@ -251,9 +324,123 @@ func ParseFromAny(v any) (Common, error) { } } + _, hasPrecision := obj[anyFieldPrecision] + _, hasScale := obj[anyFieldScale] + if hasPrecision || hasScale { + if c.Type != Decimal { + return c, fmt.Errorf("fields `precision` and `scale` are only valid for type DECIMAL, got %v", c.Type) + } + if !hasPrecision { + return c, errors.New("type DECIMAL requires field `precision`") + } + if !hasScale { + return c, errors.New("type DECIMAL requires field `scale`") + } + + precision, err := anyIntField(obj, anyFieldPrecision) + if err != nil { + return c, err + } + scale, err := anyIntField(obj, anyFieldScale) + if err != nil { + return c, err + } + + c.Logical = &LogicalParams{ + Decimal: &DecimalParams{ + Precision: precision, + Scale: scale, + }, + } + } else if c.Type == Decimal { + return c, errors.New("type DECIMAL requires fields `precision` and `scale`") + } + return c, nil } +// anyIntField extracts an integer-valued field from a map[string]any, +// accepting any of the integer or float numeric types that JSON-derived maps +// commonly produce. Float values must have no fractional part. 
func anyIntField(obj map[string]any, key string) (int32, error) {
	// Accepted dynamic types: every Go signed and unsigned integer kind,
	// float32/float64 holding an integral value, and json.Number holding an
	// integer literal. Every accepted value must also fit in an int32; any
	// other dynamic type is reported as a type error.
	v, ok := obj[key]
	if !ok {
		return 0, fmt.Errorf("missing field `%s`", key)
	}

	switch n := v.(type) {
	case int:
		return int32Bounded(int64(n), key)
	case int8:
		return int32(n), nil
	case int16:
		return int32(n), nil
	case int32:
		return n, nil
	case int64:
		return int32Bounded(n, key)
	case uint:
		return anyUintToInt32(uint64(n), key)
	case uint8:
		return int32(n), nil
	case uint16:
		return int32(n), nil
	case uint32:
		return anyUintToInt32(uint64(n), key)
	case uint64:
		return anyUintToInt32(n, key)
	case float32:
		// float32 -> float64 is exact, so the integrality test is unchanged.
		// v is forwarded so error messages print the value as it was given.
		return anyFloatToInt32(float64(n), v, key)
	case float64:
		return anyFloatToInt32(n, v, key)
	case json.Number:
		i, err := n.Int64()
		if err != nil {
			return 0, fmt.Errorf("field `%s` must be an integer, got %v", key, n)
		}
		return int32Bounded(i, key)
	default:
		return 0, fmt.Errorf("expected field `%s` of integer type, got %T", key, v)
	}
}

// anyFloatToInt32 converts an integral float to int32. NaN and values outside
// the int64 range are rejected before the float-to-int conversion, because Go
// leaves the result of converting such values implementation-dependent.
// shown is the original value, used only when formatting error messages.
func anyFloatToInt32(f float64, shown any, key string) (int32, error) {
	if f != f { // NaN is the only value that is not equal to itself.
		return 0, fmt.Errorf("field `%s` must be an integer, got %v", key, shown)
	}
	if f < -(1<<63) || f >= 1<<63 {
		return 0, fmt.Errorf("field `%s` value %v does not fit in int32", key, shown)
	}
	i := int64(f)
	if float64(i) != f {
		return 0, fmt.Errorf("field `%s` must be an integer, got %v", key, shown)
	}
	return int32Bounded(i, key)
}

// anyUintToInt32 narrows an unsigned value to int32, returning an error that
// names key when n exceeds the largest int32.
func anyUintToInt32(n uint64, key string) (int32, error) {
	if n > 1<<31-1 {
		return 0, fmt.Errorf("field `%s` value %d does not fit in int32", key, n)
	}
	return int32(n), nil
}

// int32Bounded narrows n to int32, returning an error that names key when n
// lies outside the int32 range.
func int32Bounded(n int64, key string) (int32, error) {
	const (
		maxInt32 int64 = 1<<31 - 1
		minInt32 int64 = -1 << 31
	)
	if n < minInt32 || n > maxInt32 {
		return 0, fmt.Errorf("field `%s` value %d does not fit in int32", key, n)
	}
	return int32(n), nil
}

// Validate enforces the parameter invariants of parameterised types
// (currently only [Decimal]) and that no parameter blocks are attached to
// types that do not accept them. It recurses into [Common.Children].
//
// Structural invariants — for example that an [Object] has children, or that
// a [Union] has more than one child — are not currently enforced; the
// validation surface may grow as new logical types arrive.
//
// Schemas constructed via [ParseFromAny] are validated automatically. Schemas
// constructed by struct literal should call Validate before being passed to
// converters or caches.
+func (c *Common) Validate() error { + if c.Type == Decimal { + if c.Logical == nil || c.Logical.Decimal == nil { + return errors.New("type DECIMAL requires Logical.Decimal parameters") + } + d := c.Logical.Decimal + if d.Precision < DecimalMinPrecision || d.Precision > DecimalMaxPrecision { + return fmt.Errorf("decimal precision %d out of range [%d, %d]", d.Precision, DecimalMinPrecision, DecimalMaxPrecision) + } + if d.Scale < 0 || d.Scale > d.Precision { + return fmt.Errorf("decimal scale %d out of range [0, precision=%d]", d.Scale, d.Precision) + } + } else if c.Logical != nil && c.Logical.Decimal != nil { + return fmt.Errorf("Logical.Decimal parameters are only valid for type DECIMAL, got %v", c.Type) + } + + for i, child := range c.Children { + if err := child.Validate(); err != nil { + return fmt.Errorf("child %d (%q): %w", i, child.Name, err) + } + } + + return nil +} + // Fingerprint returns a deterministic hash identifier for the schema structure. // Two schemas with the same structure will produce the same fingerprint, // regardless of memory location. This is useful for caching schema conversions @@ -281,6 +468,12 @@ func (c *Common) writeFingerprint(w io.Writer) { fmt.Fprint(w, "O:0|") } + // Write parameters for parameterised types. Only emitted when present so + // that schemas of unparameterised types retain their existing fingerprint. 
+ if c.Type == Decimal && c.Logical != nil && c.Logical.Decimal != nil { + fmt.Fprintf(w, "D:%d:%d|", c.Logical.Decimal.Precision, c.Logical.Decimal.Scale) + } + // Write children count and recursively fingerprint each child fmt.Fprintf(w, "C:%d|", len(c.Children)) for i, child := range c.Children { diff --git a/public/schema/common_test.go b/public/schema/common_test.go index f218bd6ad..a3a77148c 100644 --- a/public/schema/common_test.go +++ b/public/schema/common_test.go @@ -29,6 +29,7 @@ func TestSchemaStringify(t *testing.T) { {Input: Union, Output: "UNION"}, {Input: Timestamp, Output: "TIMESTAMP"}, {Input: Any, Output: "ANY"}, + {Input: Decimal, Output: "DECIMAL"}, {Input: zeroType, Output: "UNKNOWN"}, {Input: CommonType(-1), Output: "UNKNOWN"}, } { diff --git a/public/schema/decimal.go b/public/schema/decimal.go new file mode 100644 index 000000000..3303e5751 --- /dev/null +++ b/public/schema/decimal.go @@ -0,0 +1,189 @@ +// Copyright 2025 Redpanda Data, Inc. + +package schema + +import ( + "errors" + "fmt" + "math/big" + "strings" +) + +// NewDecimal constructs a Common schema for a fixed-precision decimal column +// and validates the precision and scale bounds. It exists so callers can avoid +// the triple-nested LogicalParams/DecimalParams struct literal at the call +// site. +func NewDecimal(name string, precision, scale int32, optional bool) (Common, error) { + c := Common{ + Name: name, + Type: Decimal, + Optional: optional, + Logical: &LogicalParams{ + Decimal: &DecimalParams{ + Precision: precision, + Scale: scale, + }, + }, + } + if err := c.Validate(); err != nil { + return Common{}, err + } + return c, nil +} + +// FormatDecimal renders an unscaled integer as the canonical decimal string +// described by the package's value contract: a leading minus sign for +// negatives, no leading plus, no leading zeros aside from a single "0" before +// the decimal point, exactly scale fractional digits, and no scientific +// notation. 
// The scale parameter must be non-negative.
//
// Precision is not enforced here; use [DecimalParams.Format] when both
// precision and scale must be checked.
//
// Examples:
//
//	FormatDecimal(big.NewInt(12345), 4) // "1.2345"
//	FormatDecimal(big.NewInt(0), 4)     // "0.0000"
//	FormatDecimal(big.NewInt(-1), 4)    // "-0.0001"
//	FormatDecimal(big.NewInt(12345), 0) // "12345"
//
// An error is returned when unscaled is nil or scale is negative. The
// argument is not modified.
func FormatDecimal(unscaled *big.Int, scale int32) (string, error) {
	if unscaled == nil {
		return "", errors.New("unscaled value must not be nil")
	}
	if scale < 0 {
		return "", fmt.Errorf("scale must be non-negative, got %d", scale)
	}

	// All length arithmetic is done in int so that len() is never narrowed
	// to int32.
	fracLen := int(scale)
	digits := new(big.Int).Abs(unscaled).String()

	// Left-pad with zeros so that at least one digit remains in front of the
	// decimal point (for example unscaled=1, scale=4 becomes "00001").
	if pad := fracLen - len(digits) + 1; fracLen > 0 && pad > 0 {
		digits = strings.Repeat("0", pad) + digits
	}

	var b strings.Builder
	b.Grow(len(digits) + 2) // room for the sign and the decimal point
	if unscaled.Sign() < 0 {
		b.WriteByte('-')
	}
	intLen := len(digits) - fracLen
	b.WriteString(digits[:intLen])
	if fracLen > 0 {
		b.WriteByte('.')
		b.WriteString(digits[intLen:])
	}
	return b.String(), nil
}

// ParseDecimal interprets s as a canonical decimal string and returns the
// unscaled integer at the given scale. Inputs with fewer fractional digits
// than scale are accepted and right-padded with zeros; inputs with more
// fractional digits than scale are rejected.
//
// Scientific notation, leading plus signs, thousands separators, multiple
// decimal points, and whitespace are not accepted. The integer part of the
// number is required (".5" is not accepted; use "0.5"). The scale parameter
// must be non-negative.
//
// Precision is not enforced here; use [DecimalParams.Parse] when both
// precision and scale must be checked.
func ParseDecimal(s string, scale int32) (*big.Int, error) {
	// Checks run in a fixed order so the first problem found is the one
	// reported: scale, emptiness, sign, decimal points, integer part, digit
	// validity, and finally the number of fractional digits.
	if scale < 0 {
		return nil, fmt.Errorf("scale must be non-negative, got %d", scale)
	}
	if s == "" {
		return nil, errors.New("decimal string must not be empty")
	}

	digits, sign := s, ""
	switch s[0] {
	case '-':
		digits, sign = s[1:], "-"
	case '+':
		return nil, errors.New("decimal string must not have a leading plus sign")
	}
	if digits == "" {
		return nil, errors.New("decimal string has no digits")
	}

	// Cut splits at the first '.', so any further '.' ends up in fracPart.
	// When there is no '.', fracPart is empty.
	intPart, fracPart, _ := strings.Cut(digits, ".")
	if strings.Contains(fracPart, ".") {
		return nil, errors.New("decimal string must contain at most one decimal point")
	}
	if intPart == "" {
		return nil, errors.New("decimal string is missing the integer part")
	}

	if err := requireDigits(intPart); err != nil {
		return nil, err
	}
	if err := requireDigits(fracPart); err != nil {
		return nil, err
	}

	// Compare in int rather than narrowing len() to int32: a wrapped length
	// would slip past this check and give strings.Repeat a negative count,
	// which panics.
	fracLen, want := len(fracPart), int(scale)
	if fracLen > want {
		return nil, fmt.Errorf("decimal string has %d fractional digits, exceeds scale %d", fracLen, scale)
	}

	// Right-pad the fraction to exactly scale digits; the concatenated digit
	// string is then the unscaled integer.
	raw := sign + intPart + fracPart + strings.Repeat("0", want-fracLen)

	n, ok := new(big.Int).SetString(raw, 10)
	if !ok {
		return nil, fmt.Errorf("failed to parse decimal value %q", s)
	}
	return n, nil
}

// requireDigits returns an error naming the first rune of s that is not an
// ASCII digit. The empty string is accepted.
func requireDigits(s string) error {
	for _, r := range s {
		if r < '0' || r > '9' {
			return fmt.Errorf("decimal string contains non-digit %q", r)
		}
	}
	return nil
}

// Format renders the unscaled integer as a canonical decimal string at the
// configured scale, and rejects values whose magnitude exceeds the configured
// precision.
+func (p DecimalParams) Format(unscaled *big.Int) (string, error) { + if err := p.ValidateValue(unscaled); err != nil { + return "", err + } + return FormatDecimal(unscaled, p.Scale) +} + +// Parse interprets s as a canonical decimal string at the configured scale, +// and rejects values whose magnitude exceeds the configured precision. +func (p DecimalParams) Parse(s string) (*big.Int, error) { + n, err := ParseDecimal(s, p.Scale) + if err != nil { + return nil, err + } + if err := p.ValidateValue(n); err != nil { + return nil, err + } + return n, nil +} + +// ValidateValue confirms that the magnitude of v has no more significant +// digits than the configured precision. The configured precision and scale +// are not themselves validated by this method; use [Common.Validate] for +// that. +func (p DecimalParams) ValidateValue(v *big.Int) error { + if v == nil { + return errors.New("decimal value must not be nil") + } + digits := decimalDigits(v) + if int32(digits) > p.Precision { + return fmt.Errorf("decimal value has %d significant digits, exceeds precision %d", digits, p.Precision) + } + return nil +} + +// decimalDigits returns the number of digits in the decimal representation of +// the absolute value of v. The digit count of zero is one. +func decimalDigits(v *big.Int) int { + return len(new(big.Int).Abs(v).String()) +} diff --git a/public/schema/decimal_test.go b/public/schema/decimal_test.go new file mode 100644 index 000000000..a4fe61431 --- /dev/null +++ b/public/schema/decimal_test.go @@ -0,0 +1,568 @@ +// Copyright 2025 Redpanda Data, Inc. 
+ +package schema + +import ( + "bytes" + "encoding/json" + "math/big" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecimalToAnyEmitsParams(t *testing.T) { + c := Common{ + Type: Decimal, + Name: "amount", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 18, Scale: 4}}, + } + + m, ok := c.ToAny().(map[string]any) + require.True(t, ok) + + assert.Equal(t, "DECIMAL", m[anyFieldType]) + assert.Equal(t, "amount", m[anyFieldName]) + assert.Equal(t, int64(18), m[anyFieldPrecision]) + assert.Equal(t, int64(4), m[anyFieldScale]) +} + +func TestDecimalNonDecimalDoesNotEmitParams(t *testing.T) { + c := Common{Type: Int64, Name: "count"} + m, ok := c.ToAny().(map[string]any) + require.True(t, ok) + + _, hasPrecision := m[anyFieldPrecision] + _, hasScale := m[anyFieldScale] + assert.False(t, hasPrecision, "non-decimal types must not emit precision") + assert.False(t, hasScale, "non-decimal types must not emit scale") +} + +func TestDecimalRoundTrip(t *testing.T) { + original := Common{ + Type: Decimal, + Name: "balance", + Optional: true, + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 38, Scale: 10}}, + } + + parsed, err := ParseFromAny(original.ToAny()) + require.NoError(t, err) + + assert.Equal(t, original.Type, parsed.Type) + assert.Equal(t, original.Name, parsed.Name) + assert.Equal(t, original.Optional, parsed.Optional) + require.NotNil(t, parsed.Logical) + require.NotNil(t, parsed.Logical.Decimal) + assert.Equal(t, int32(38), parsed.Logical.Decimal.Precision) + assert.Equal(t, int32(10), parsed.Logical.Decimal.Scale) + assert.Equal(t, original.fingerprint(), parsed.fingerprint()) +} + +func TestDecimalRoundTripNested(t *testing.T) { + original := Common{ + Type: Object, + Name: "row", + Children: []Common{ + {Type: String, Name: "id"}, + { + Type: Decimal, + Name: "price", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 12, Scale: 2}}, + }, + { + Type: 
Decimal, + Name: "fee", + Optional: true, + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 8, Scale: 4}}, + }, + }, + } + + parsed, err := ParseFromAny(original.ToAny()) + require.NoError(t, err) + assert.Equal(t, original.fingerprint(), parsed.fingerprint()) +} + +func TestDecimalParseFromAnyJSONNumber(t *testing.T) { + // Simulate what happens when a caller decodes the Any form via + // json.Decoder.UseNumber() — precision and scale arrive as json.Number + // rather than float64. + in := map[string]any{ + anyFieldType: "DECIMAL", + anyFieldName: "amount", + anyFieldPrecision: json.Number("20"), + anyFieldScale: json.Number("6"), + } + + c, err := ParseFromAny(in) + require.NoError(t, err) + require.NotNil(t, c.Logical) + require.NotNil(t, c.Logical.Decimal) + assert.Equal(t, int32(20), c.Logical.Decimal.Precision) + assert.Equal(t, int32(6), c.Logical.Decimal.Scale) +} + +func TestDecimalParseFromAnyJSONNumberFractional(t *testing.T) { + in := map[string]any{ + anyFieldType: "DECIMAL", + anyFieldName: "x", + anyFieldPrecision: json.Number("10.5"), + anyFieldScale: json.Number("2"), + } + + _, err := ParseFromAny(in) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be an integer") +} + +func TestDecimalRoundTripThroughJSONUseNumber(t *testing.T) { + // End-to-end: serialise via ToAny, encode to JSON, decode with + // UseNumber, parse back through ParseFromAny. Exercises the contract + // for any caller that pipes schemas through a JSON layer with the + // numeric-precision-preserving decoder configuration. 
+ original := Common{ + Type: Decimal, + Name: "amount", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 38, Scale: 10}}, + } + + encoded, err := json.Marshal(original.ToAny()) + require.NoError(t, err) + + dec := json.NewDecoder(bytes.NewReader(encoded)) + dec.UseNumber() + var decoded any + require.NoError(t, dec.Decode(&decoded)) + + parsed, err := ParseFromAny(decoded) + require.NoError(t, err) + assert.Equal(t, original.fingerprint(), parsed.fingerprint()) +} + +func TestDecimalParseFromAnyFloatPrecision(t *testing.T) { + // JSON unmarshalling produces float64s for numbers; ensure we accept them + // when they have no fractional part. + in := map[string]any{ + anyFieldType: "DECIMAL", + anyFieldName: "amount", + anyFieldPrecision: float64(20), + anyFieldScale: float64(6), + } + + c, err := ParseFromAny(in) + require.NoError(t, err) + require.NotNil(t, c.Logical) + require.NotNil(t, c.Logical.Decimal) + assert.Equal(t, int32(20), c.Logical.Decimal.Precision) + assert.Equal(t, int32(6), c.Logical.Decimal.Scale) +} + +func TestDecimalValidate(t *testing.T) { + tests := []struct { + name string + schema Common + wantErr string + }{ + { + name: "valid", + schema: Common{ + Type: Decimal, + Name: "x", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 10, Scale: 2}}, + }, + }, + { + name: "valid scale equals precision", + schema: Common{ + Type: Decimal, + Name: "x", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 5, Scale: 5}}, + }, + }, + { + name: "valid scale zero", + schema: Common{ + Type: Decimal, + Name: "x", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 38, Scale: 0}}, + }, + }, + { + name: "missing logical", + schema: Common{Type: Decimal, Name: "x"}, + wantErr: "requires Logical.Decimal parameters", + }, + { + name: "missing decimal params", + schema: Common{ + Type: Decimal, + Name: "x", + Logical: &LogicalParams{}, + }, + wantErr: "requires Logical.Decimal parameters", + }, + { + name: "precision below 
minimum", + schema: Common{ + Type: Decimal, + Name: "x", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 0, Scale: 0}}, + }, + wantErr: "precision 0 out of range", + }, + { + name: "precision above maximum", + schema: Common{ + Type: Decimal, + Name: "x", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 39, Scale: 0}}, + }, + wantErr: "precision 39 out of range", + }, + { + name: "negative scale", + schema: Common{ + Type: Decimal, + Name: "x", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 10, Scale: -1}}, + }, + wantErr: "scale -1 out of range", + }, + { + name: "scale exceeds precision", + schema: Common{ + Type: Decimal, + Name: "x", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 5, Scale: 6}}, + }, + wantErr: "scale 6 out of range", + }, + { + name: "decimal params on non-decimal type", + schema: Common{ + Type: Int64, + Name: "x", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 5, Scale: 2}}, + }, + wantErr: "only valid for type DECIMAL", + }, + { + name: "child validation propagates", + schema: Common{ + Type: Object, + Name: "row", + Children: []Common{ + { + Type: Decimal, + Name: "bad", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 0, Scale: 0}}, + }, + }, + }, + wantErr: `child 0 ("bad")`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.schema.Validate() + if tt.wantErr == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} + +func TestDecimalParseFromAnyErrors(t *testing.T) { + tests := []struct { + name string + input map[string]any + wantErr string + }{ + { + name: "decimal missing precision and scale", + input: map[string]any{ + anyFieldType: "DECIMAL", + anyFieldName: "x", + }, + wantErr: "requires fields `precision` and `scale`", + }, + { + name: "decimal missing scale", + input: map[string]any{ + anyFieldType: "DECIMAL", + anyFieldName: "x", + 
anyFieldPrecision: int64(10), + }, + wantErr: "requires field `scale`", + }, + { + name: "decimal missing precision", + input: map[string]any{ + anyFieldType: "DECIMAL", + anyFieldName: "x", + anyFieldScale: int64(2), + }, + wantErr: "requires field `precision`", + }, + { + name: "precision/scale on non-decimal", + input: map[string]any{ + anyFieldType: "INT64", + anyFieldName: "x", + anyFieldPrecision: int64(10), + anyFieldScale: int64(2), + }, + wantErr: "only valid for type DECIMAL", + }, + { + name: "non-integer precision", + input: map[string]any{ + anyFieldType: "DECIMAL", + anyFieldName: "x", + anyFieldPrecision: 10.5, + anyFieldScale: int64(2), + }, + wantErr: "must be an integer", + }, + { + name: "wrong type for precision", + input: map[string]any{ + anyFieldType: "DECIMAL", + anyFieldName: "x", + anyFieldPrecision: "10", + anyFieldScale: int64(2), + }, + wantErr: "expected field `precision` of integer type", + }, + { + name: "validation runs on parse", + input: map[string]any{ + anyFieldType: "DECIMAL", + anyFieldName: "x", + anyFieldPrecision: int64(50), + anyFieldScale: int64(2), + }, + wantErr: "out of range", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseFromAny(tt.input) + require.Error(t, err) + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error()) + } + }) + } +} + +func TestNewDecimal(t *testing.T) { + c, err := NewDecimal("amount", 18, 4, true) + require.NoError(t, err) + assert.Equal(t, Decimal, c.Type) + assert.Equal(t, "amount", c.Name) + assert.True(t, c.Optional) + require.NotNil(t, c.Logical) + require.NotNil(t, c.Logical.Decimal) + assert.Equal(t, int32(18), c.Logical.Decimal.Precision) + assert.Equal(t, int32(4), c.Logical.Decimal.Scale) +} + +func TestNewDecimalRejectsInvalid(t *testing.T) { + _, err := NewDecimal("x", 0, 0, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "precision 0 out of range") + + 
_, err = NewDecimal("x", 5, 6, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "scale 6 out of range") + + _, err = NewDecimal("x", 5, -1, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "scale -1 out of range") +} + +func TestFormatDecimal(t *testing.T) { + tests := []struct { + name string + unscaled string + scale int32 + want string + }{ + {"zero scale zero", "0", 0, "0"}, + {"zero scale four", "0", 4, "0.0000"}, + {"one scale zero", "1", 0, "1"}, + {"one scale four", "1", 4, "0.0001"}, + {"negative one scale four", "-1", 4, "-0.0001"}, + {"twelve thousand scale zero", "12345", 0, "12345"}, + {"twelve thousand scale two", "12345", 2, "123.45"}, + {"twelve thousand scale four", "12345", 4, "1.2345"}, + {"twelve thousand scale five", "12345", 5, "0.12345"}, + {"twelve thousand scale six", "12345", 6, "0.012345"}, + {"negative scale two", "-12345", 2, "-123.45"}, + {"max precision", "12345678901234567890123456789012345678", 0, "12345678901234567890123456789012345678"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n, ok := new(big.Int).SetString(tt.unscaled, 10) + require.True(t, ok) + got, err := FormatDecimal(n, tt.scale) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestFormatDecimalErrors(t *testing.T) { + _, err := FormatDecimal(nil, 0) + require.Error(t, err) + assert.Contains(t, err.Error(), "must not be nil") + + _, err = FormatDecimal(big.NewInt(1), -1) + require.Error(t, err) + assert.Contains(t, err.Error(), "scale must be non-negative") +} + +func TestParseDecimal(t *testing.T) { + tests := []struct { + name string + input string + scale int32 + want string + }{ + {"zero scale zero", "0", 0, "0"}, + {"zero scale four", "0.0000", 4, "0"}, + {"one scale zero", "1", 0, "1"}, + {"one scale four", "0.0001", 4, "1"}, + {"negative scale four", "-0.0001", 4, "-1"}, + {"twelve thousand scale four", "1.2345", 4, "12345"}, + {"twelve thousand scale two", "123.45", 2, 
"12345"}, + {"twelve thousand scale zero", "12345", 0, "12345"}, + {"pad fewer fractional digits", "1.5", 4, "15000"}, + {"integer to scale two", "12345", 2, "1234500"}, + {"trailing dot allowed", "1.", 0, "1"}, + {"trailing dot with scale", "1.", 3, "1000"}, + {"negative integer", "-123", 0, "-123"}, + {"max precision", "12345678901234567890123456789012345678", 0, "12345678901234567890123456789012345678"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n, err := ParseDecimal(tt.input, tt.scale) + require.NoError(t, err) + assert.Equal(t, tt.want, n.String()) + }) + } +} + +func TestParseDecimalErrors(t *testing.T) { + tests := []struct { + name string + input string + scale int32 + wantErr string + }{ + {"empty", "", 0, "must not be empty"}, + {"just minus", "-", 0, "no digits"}, + {"leading plus", "+1", 0, "must not have a leading plus"}, + {"missing integer part", ".5", 1, "missing the integer part"}, + {"two dots", "1.2.3", 1, "at most one decimal point"}, + {"non-digit", "1.2a", 1, "non-digit"}, + {"scientific notation", "1e5", 0, "non-digit"}, + {"whitespace", " 1.5", 4, "non-digit"}, + {"trailing whitespace", "1.5 ", 4, "non-digit"}, + {"thousands separator", "1,000", 0, "non-digit"}, + {"too many fractional digits", "1.23456", 4, "exceeds scale 4"}, + {"negative scale", "1.5", -1, "scale must be non-negative"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseDecimal(tt.input, tt.scale) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} + +func TestDecimalRoundTripFormatParse(t *testing.T) { + values := []struct { + unscaled string + scale int32 + }{ + {"0", 0}, + {"0", 4}, + {"12345", 4}, + {"-12345", 4}, + {"1", 38}, + {"-1", 38}, + {"12345678901234567890123456789012345678", 0}, + } + + for _, v := range values { + t.Run(v.unscaled+"@"+itoa(v.scale), func(t *testing.T) { + n, ok := new(big.Int).SetString(v.unscaled, 10) + require.True(t, ok) + s, err := 
FormatDecimal(n, v.scale) + require.NoError(t, err) + parsed, err := ParseDecimal(s, v.scale) + require.NoError(t, err) + assert.Equal(t, 0, n.Cmp(parsed), "round trip mismatch: %s vs %s", n, parsed) + }) + } +} + +func itoa(v int32) string { + return new(big.Int).SetInt64(int64(v)).String() +} + +func TestDecimalParamsFormatRejectsOverflow(t *testing.T) { + p := DecimalParams{Precision: 5, Scale: 2} + + // Within precision. + s, err := p.Format(big.NewInt(99999)) + require.NoError(t, err) + assert.Equal(t, "999.99", s) + + // Exceeds precision. + _, err = p.Format(big.NewInt(123456)) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds precision 5") +} + +func TestDecimalParamsParseRejectsOverflow(t *testing.T) { + p := DecimalParams{Precision: 5, Scale: 2} + + n, err := p.Parse("999.99") + require.NoError(t, err) + assert.Equal(t, "99999", n.String()) + + _, err = p.Parse("9999.99") + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds precision 5") +} + +func TestDecimalParamsValidateValue(t *testing.T) { + p := DecimalParams{Precision: 5, Scale: 2} + + assert.NoError(t, p.ValidateValue(big.NewInt(0))) + assert.NoError(t, p.ValidateValue(big.NewInt(99999))) + assert.NoError(t, p.ValidateValue(big.NewInt(-99999))) + + err := p.ValidateValue(big.NewInt(100000)) + require.Error(t, err) + assert.Contains(t, err.Error(), "6 significant digits") + + err = p.ValidateValue(nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "must not be nil") +} diff --git a/public/schema/decimal_types.md b/public/schema/decimal_types.md new file mode 100644 index 000000000..c5775d362 --- /dev/null +++ b/public/schema/decimal_types.md @@ -0,0 +1,338 @@ +# Decimal types in `public/schema` + +This document describes the `Decimal` common type and its parameterised +representation, and lays out the contracts that schema-format converters and +data-source plugins must honour when handling decimal values. 
+ +## Why decimals are a special case + +Most numeric types in the common schema (`Int32`, `Int64`, `Float32`, `Float64`) +have a fixed bit-width and need no further parameters. Decimals don't: a +decimal value's meaning depends on its **precision** (total significant digits) +and **scale** (digits to the right of the decimal point). The same byte +sequence `0x00 0x7B` encodes `123` at scale 0, `12.3` at scale 1, and `1.23` at +scale 2. + +Therefore the common schema needs to carry these parameters alongside the type +identifier, and every downstream converter and data-source plugin must agree on +how the parameters and the values they describe travel together. + +## Changes to the common schema + +### New type + +```go +const Decimal CommonType = 15 +``` + +`Decimal` joins the existing primitive and structural types and stringifies as +`"DECIMAL"`. + +### New parameter struct + +A new optional field is added to `Common` for parameterised types in general, +not only decimal: + +```go +type Common struct { + Name string + Type CommonType + Optional bool + Children []Common + Logical *LogicalParams // nil when no logical parameters are required +} + +type LogicalParams struct { + Decimal *DecimalParams + // Future parameterised logical types add their own pointer field here. +} + +type DecimalParams struct { + Precision int32 + Scale int32 +} +``` + +Only the `LogicalParams` field corresponding to `Common.Type` is allowed to be +non-nil. Setting `Logical.Decimal` on a non-`Decimal` schema is a validation +error. + +### Bounds + +```go +const ( + DecimalMinPrecision int32 = 1 + DecimalMaxPrecision int32 = 38 +) +``` + +Validation rules enforced by `Common.Validate()` and applied by +`ParseFromAny`: + +- `Precision ∈ [DecimalMinPrecision, DecimalMaxPrecision]` +- `Scale ∈ [0, Precision]` + +These bounds describe the **lossless intersection** across Avro `decimal`, +Parquet `DECIMAL`, and Oracle `NUMBER`. 
Oracle permits negative scale and +precisions up to its own internal limits, but those values cannot round-trip +through Avro or Parquet, so the common schema does not allow them. Sources that +encounter wider Oracle decimals should either narrow them or downgrade to +`String` and document the loss. + +### Serialisation in `ToAny` / `ParseFromAny` + +For decimals, `ToAny` adds two top-level fields to the map: + +```json +{ + "type": "DECIMAL", + "name": "amount", + "precision": 18, + "scale": 4, + "fingerprint": "..." +} +``` + +`ParseFromAny` requires both fields when `type` is `DECIMAL`, rejects them on +any other type, and runs full validation before returning. Numeric values are +accepted as `int`, `int32`, `int64`, `float32` or `float64`, the latter two +provided they have no fractional part — JSON unmarshalling tends to produce +`float64`s. `encoding/json.Number` values, as produced by +`json.Decoder.UseNumber()`, are accepted as well. + +### Fingerprinting + +`writeFingerprint` includes a `D:<precision>:<scale>|` segment **only** when +the type is `Decimal`. Non-decimal schemas keep the byte-for-byte canonical +form they had before, so existing fingerprints (and cached conversions keyed by +them) remain stable. + +### Inference + +`InferFromAny` does not infer decimals. Go has no canonical decimal type and +there is no reliable way to recover precision and scale from a generic Go value +without context. Decimal schemas must be constructed explicitly by data-source +plugins from authoritative source metadata. + +## Contract for schema-format converters + +Converters live outside this package (Avro, Parquet, Iceberg, JSON Schema, +Protobuf, ...). When a converter encounters a `Decimal` common schema it +**must**: + +1. Read precision and scale from `c.Logical.Decimal`. Treat `c.Logical == nil` + or `c.Logical.Decimal == nil` as a programming error and return an error, + not a default. +2. Pick the format-native decimal representation that preserves precision and + scale exactly. See per-format guidance below. +3. 
Refuse precisions or scales the target format cannot represent rather than + silently truncating. The common schema's bounds are conservative, so most + targets will never need to reject; those that do must surface a clear + error. + +When **producing** a `Common` schema from a format-native schema, the converter +constructs `&LogicalParams{Decimal: &DecimalParams{...}}` from the source +precision and scale and runs `Common.Validate()` before returning. + +### Avro + +Avro's `decimal` is a logical type built on top of `bytes` or `fixed`. +Converters should: + +- For schemas read from Avro: take `precision` and `scale` from the logical + type annotation. If `scale` is absent, default it to `0` (Avro spec + default). +- For schemas written to Avro: prefer `bytes` as the underlying primitive + unless the conversion target requires `fixed` (e.g. for fixed-width on-wire + framing). When using `fixed`, compute `size = ceil((precision * log2(10) + + 1) / 8)`. +- Reject Avro schemas where `scale > precision` or `precision <= 0` — these + are invalid in Avro itself and would fail validation in the common schema + too. + +The on-wire Avro decimal value is the two's-complement big-endian encoding of +the **unscaled** integer; the logical value it represents is that integer +multiplied by `10^-scale`. Converters that operate on Avro records therefore +read the unscaled integer directly from the bytes on the way in, and turn the +canonical string form back into the unscaled integer (e.g. with +`schema.ParseDecimal`) on the way out; no further multiplication or division +by `10^scale` is applied to the wire value itself. + +### Parquet + +Parquet's `DECIMAL` logical type wraps one of four physical types, chosen by +precision: + +| Precision range | Physical type | +|-----------------|---------------------------| +| 1 – 9 | `INT32` | +| 10 – 18 | `INT64` | +| 19 – 38 | `FIXED_LEN_BYTE_ARRAY` | +| arbitrary | `BYTE_ARRAY` | + +Converters should: + +- For schemas written to Parquet: select the smallest physical type capable of + representing the precision. `FIXED_LEN_BYTE_ARRAY` length is + `ceil((precision * log2(10) + 1) / 8)`. +- For schemas read from Parquet: require both `precision` and `scale` + annotations. 
Reject decimals encoded as bare `BYTE_ARRAY` without a logical + type annotation, since precision and scale are not recoverable from the + bytes alone. + +Parquet shares Avro's two's-complement big-endian wire format for the +byte-backed cases, and uses native two's-complement for the integer-backed +cases. + +### Oracle / databases with `NUMBER(p, s)` + +Sources reading from `NUMBER(p, s)` set `Precision = p` and `Scale = s`. The +following conditions must be handled explicitly: + +- `NUMBER` with **no** declared precision (Oracle's "floating decimal"): there + is no fixed precision to record. Sources must either pick a sentinel + precision (e.g. 38) and warn, or downgrade to `String`. +- `NUMBER` with declared precision but **no** scale: `Scale = 0`. +- `NUMBER` with **negative** scale: not supported. Sources must either round + to scale 0, downgrade to `String`, or refuse the column. + +### Postgres `NUMERIC` / MySQL `DECIMAL` / `NUMERIC` + +These map directly: precision and scale from the column metadata translate +straight to `DecimalParams`. MySQL enforces `0 ≤ scale ≤ precision`, as does +Postgres before version 15 (later versions also permit a negative scale and a +scale greater than the precision). Both databases, however, allow precisions +above `DecimalMaxPrecision` (MySQL up to 65, Postgres up to 1000), so sources +must still validate the resulting schema and either downgrade to `String` or +refuse columns that fall outside the common-schema bounds. + +`NUMERIC` columns with no precision (Postgres "arbitrary precision") fall into +the same bucket as undeclared Oracle `NUMBER`: pick a precision and warn, or +downgrade to `String`. + +### JSON Schema + +JSON Schema has no native decimal. Converters should map `Decimal` to +`{"type": "string", "pattern": ...}` with a regex that matches the precision +and scale, and document the loss of arithmetic semantics in the +roundtripped schema. Inbound conversion (JSON Schema → common) cannot recover +`Decimal` and should retain the value as `String`. + +## Contract for data-source plugins + +Data-source plugins (CDC inputs like `mysql_cdc`, `postgres_cdc`, `oracle_cdc`, +batch inputs like `sql_select`, etc.) emit two things: a **schema** describing +each column, and **values** for each row. 
+ +### Producing the schema + +When a source identifies a column as a fixed-precision decimal, prefer the +constructor helper: + +```go +col, err := schema.NewDecimal("amount", precisionFromSource, scaleFromSource, nullable) +if err != nil { + return err +} +``` + +`NewDecimal` validates the precision and scale once at schema-discovery time. +Per-row validation is unnecessary and should be avoided on hot paths. + +The constructor is shorthand for the equivalent struct literal, which remains +available for cases that need it (e.g. constructing a parent [Common] schema +in a single expression): + +```go +col := schema.Common{ + Name: "amount", + Type: schema.Decimal, + Optional: nullable, + Logical: &schema.LogicalParams{ + Decimal: &schema.DecimalParams{ + Precision: precisionFromSource, + Scale: scaleFromSource, + }, + }, +} +if err := col.Validate(); err != nil { + return err +} +``` + +### Producing values + +The benthos message body that travels alongside the schema should encode each +decimal value in **canonical string form**: + +- A leading minus sign for negative values; no leading plus sign. +- No leading zeros except for the single `0` before a decimal point. +- A decimal point appears if and only if `scale > 0`. +- Exactly `scale` digits after the decimal point — sources must pad with + trailing zeros if necessary so that `"1.5"` for a `(p, 4)` column is + emitted as `"1.5000"`. +- No scientific notation, thousands separators, or whitespace. + +Examples for `Precision=18, Scale=4`: + +| Source value | Emitted string | +|--------------|----------------| +| `12345` | `"12345.0000"` | +| `-0.1` | `"-0.1000"` | +| `0` | `"0.0000"` | + +Strings are chosen as the canonical form because they: + +- Survive JSON round-trips without floating-point loss. +- Pass cleanly through Bloblang's existing string-handling primitives. +- Can be parsed by every downstream converter (Avro, Parquet, ...) into the + format-native unscaled integer when needed. 
+- Don't bind benthos to a specific Go decimal library. + +To produce and consume the canonical form consistently across plugins, use +the helpers in this package: + +```go +// Producing a value (e.g. in a CDC source after reading the raw decimal): +unscaled := big.NewInt(15000) +str, err := schema.FormatDecimal(unscaled, scale) // "1.5000" at scale 4 + +// Or, with precision enforcement: +params := schema.DecimalParams{Precision: 18, Scale: 4} +str, err := params.Format(unscaled) + +// Consuming a value (e.g. in a converter writing to Avro/Parquet): +unscaled, err := schema.ParseDecimal("1.5000", scale) // big.NewInt(15000) +unscaled, err := params.Parse("1.5000") // also enforces precision +``` + +Plugins that roll their own formatting are likely to drift from the contract +(scientific notation, trailing-zero handling, sign-zero, leading zeros). Use +the helpers. + +### Optional fast paths for converters + +Converters that want to avoid string parsing on hot paths **may** accept +additional value forms — but the canonical string form is mandatory and is +what data-source plugins are required to emit. Suggested optional forms a +converter can opt in to: + +- `[]byte` containing the two's-complement big-endian unscaled integer + (matches the Avro/Parquet wire format). +- `*big.Int` containing the unscaled integer (the form returned by + `schema.ParseDecimal` and accepted by `schema.FormatDecimal`). + +These fast paths are **opt-in for the converter, not optional for the +source**. A new data-source plugin that does not emit canonical strings is +non-conformant. + +### Null values + +A nullable decimal column emits a Go `nil` value. The schema's `Optional` +field carries the nullability information; the value form is unchanged +otherwise. + +## Migration notes for existing converters and sources + +This change is additive: + +- Existing schemas that did not previously contain decimals are unaffected. 
+ Their fingerprints are byte-for-byte identical to before, so cached + conversions remain valid. +- Existing converters that do not handle the `Decimal` type should continue to + reject it with `"unsupported type"` until updated. +- Existing data sources that previously surfaced decimal columns as `String` + may continue to do so for backwards compatibility, but should migrate to + emitting `Decimal` schemas with canonical-string values when possible. diff --git a/public/schema/fingerprint_test.go b/public/schema/fingerprint_test.go index 0c2e0a7fe..664c52408 100644 --- a/public/schema/fingerprint_test.go +++ b/public/schema/fingerprint_test.go @@ -160,6 +160,48 @@ func TestFingerprint(t *testing.T) { }, shouldMatch: false, }, + { + name: "identical decimal params", + schema1: Common{ + Type: Decimal, + Name: "amount", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 18, Scale: 4}}, + }, + schema2: Common{ + Type: Decimal, + Name: "amount", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 18, Scale: 4}}, + }, + shouldMatch: true, + }, + { + name: "different decimal precision", + schema1: Common{ + Type: Decimal, + Name: "amount", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 18, Scale: 4}}, + }, + schema2: Common{ + Type: Decimal, + Name: "amount", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 20, Scale: 4}}, + }, + shouldMatch: false, + }, + { + name: "different decimal scale", + schema1: Common{ + Type: Decimal, + Name: "amount", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 18, Scale: 4}}, + }, + schema2: Common{ + Type: Decimal, + Name: "amount", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 18, Scale: 6}}, + }, + shouldMatch: false, + }, } for _, tt := range tests { @@ -215,27 +257,38 @@ func TestFingerprintDeterministic(t *testing.T) { } func TestFingerprintAllTypes(t *testing.T) { - types := []CommonType{ - Boolean, Int32, Int64, Float32, Float64, - String, ByteArray, 
Object, Map, Array, - Null, Union, Timestamp, Any, + schemas := []Common{ + {Type: Boolean, Name: "test"}, + {Type: Int32, Name: "test"}, + {Type: Int64, Name: "test"}, + {Type: Float32, Name: "test"}, + {Type: Float64, Name: "test"}, + {Type: String, Name: "test"}, + {Type: ByteArray, Name: "test"}, + {Type: Object, Name: "test"}, + {Type: Map, Name: "test"}, + {Type: Array, Name: "test"}, + {Type: Null, Name: "test"}, + {Type: Union, Name: "test"}, + {Type: Timestamp, Name: "test"}, + {Type: Any, Name: "test"}, + {Type: Decimal, Name: "test", Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 10, Scale: 2}}}, } fingerprints := make(map[string]CommonType) - for _, typ := range types { - schema := Common{Type: typ, Name: "test"} + for _, schema := range schemas { fp := schema.fingerprint() if fp == "" { - t.Errorf("fingerprint for type %v should not be empty", typ) + t.Errorf("fingerprint for type %v should not be empty", schema.Type) } if existing, exists := fingerprints[fp]; exists { - t.Errorf("fingerprint collision between types %v and %v", existing, typ) + t.Errorf("fingerprint collision between types %v and %v", existing, schema.Type) } - fingerprints[fp] = typ + fingerprints[fp] = schema.Type } } @@ -277,6 +330,14 @@ func TestToAnyIncludesFingerprint(t *testing.T) { Name: "payload", }, }, + { + name: "decimal schema", + schema: Common{ + Type: Decimal, + Name: "amount", + Logical: &LogicalParams{Decimal: &DecimalParams{Precision: 18, Scale: 4}}, + }, + }, { name: "deeply nested schema", schema: Common{ diff --git a/public/schema/infer_from_any.go b/public/schema/infer_from_any.go index 204a73b9e..2abfbdee3 100644 --- a/public/schema/infer_from_any.go +++ b/public/schema/infer_from_any.go @@ -3,6 +3,7 @@ package schema import ( + "encoding/json" "fmt" "sort" "time" @@ -22,6 +23,17 @@ func inferFromAny(name string, v any) (Common, error) { c.Type = Float32 case float64: c.Type = Float64 + case json.Number: + // json.Number is produced by 
json.Decoder.UseNumber(); it has no + // int-vs-float discriminator, so try integer parsing first and fall + // back to float. + if _, err := t.Int64(); err == nil { + c.Type = Int64 + } else if _, err := t.Float64(); err == nil { + c.Type = Float64 + } else { + return c, fmt.Errorf(" json.Number value %q is not parseable as int64 or float64", string(t)) + } case []byte: c.Type = ByteArray case string: @@ -71,7 +83,13 @@ func inferFromAny(name string, v any) (Common, error) { // InferFromAny attempts to infer a common schema from any Go value. This // process fails if the value, or any children of a provided map/slice, are not // within the following subset of Go types: bool, int, int32, int64, float32, -// float64, []byte, string, map[string]any, []any. +// float64, [encoding/json.Number], []byte, string, map[string]any, []any. +// +// [encoding/json.Number] values are inferred as Int64 when they parse as an +// integer and as Float64 otherwise. +// +// Parameterised logical types (e.g. Decimal) cannot be inferred from generic Go +// values and must be constructed explicitly. // // All values will be recorded as non-optional. 
func InferFromAny(v any) (Common, error) { diff --git a/public/schema/infer_from_any_test.go b/public/schema/infer_from_any_test.go index bf32760dd..c8948da8f 100644 --- a/public/schema/infer_from_any_test.go +++ b/public/schema/infer_from_any_test.go @@ -3,6 +3,7 @@ package schema import ( + "encoding/json" "testing" "time" @@ -120,6 +121,25 @@ func TestFromAnySchema(t *testing.T) { }, ErrContains: "mismatched array types", }, + { + Name: "json.Number integer", + Input: json.Number("12345"), + Output: Common{ + Type: Int64, + }, + }, + { + Name: "json.Number float", + Input: json.Number("1.5"), + Output: Common{ + Type: Float64, + }, + }, + { + Name: "json.Number invalid", + Input: json.Number("not-a-number"), + ErrContains: "not parseable as int64 or float64", + }, } { t.Run(test.Name, func(t *testing.T) { res, err := InferFromAny(test.Input)