From 2116422db13583d921a6bb64766b96f28cf79923 Mon Sep 17 00:00:00 2001 From: Brad Peabody Date: Tue, 24 Feb 2026 08:03:04 -0800 Subject: [PATCH] implemented UBJ and JSON models for latest XGBoost --- internal/ubjdecode/decode.go | 536 ++++++++++++++++++ internal/xgjson/detect.go | 121 ++++ internal/xgjson/detect_test.go | 110 ++++ internal/xgjson/testdata/README.md | 22 + internal/xgjson/testdata/generate.go | 3 + .../xgjson/testdata/test_binary_logistic.json | 1 + .../xgjson/testdata/test_binary_logistic.py | 74 +++ .../xgjson/testdata/test_binary_logistic.ubj | Bin 0 -> 11883 bytes .../test_binary_logistic_expected.json | 47 ++ internal/xgjson/testdata/test_multiclass.json | 1 + internal/xgjson/testdata/test_multiclass.py | 97 ++++ internal/xgjson/testdata/test_multiclass.ubj | Bin 0 -> 34738 bytes .../testdata/test_multiclass_expected.json | 92 +++ internal/xgjson/testdata/test_poisson.json | 1 + internal/xgjson/testdata/test_poisson.py | 69 +++ internal/xgjson/testdata/test_poisson.ubj | Bin 0 -> 12509 bytes .../testdata/test_poisson_expected.json | 47 ++ internal/xgjson/testdata/test_regression.json | 1 + internal/xgjson/testdata/test_regression.py | 65 +++ internal/xgjson/testdata/test_regression.ubj | Bin 0 -> 12498 bytes .../testdata/test_regression_expected.json | 47 ++ internal/xgjson/testdata/train.py | 21 + internal/xgjson/xgjson_io.go | 468 +++++++++++++++ xgensemble_io.go | 38 +- xgensemble_json_io.go | 303 ++++++++++ xgensemble_json_test.go | 318 +++++++++++ 26 files changed, 2481 insertions(+), 1 deletion(-) create mode 100644 internal/ubjdecode/decode.go create mode 100644 internal/xgjson/detect.go create mode 100644 internal/xgjson/detect_test.go create mode 100644 internal/xgjson/testdata/README.md create mode 100644 internal/xgjson/testdata/generate.go create mode 100644 internal/xgjson/testdata/test_binary_logistic.json create mode 100644 internal/xgjson/testdata/test_binary_logistic.py create mode 100644 internal/xgjson/testdata/test_binary_logistic.ubj create mode 100644 internal/xgjson/testdata/test_binary_logistic_expected.json create mode 100644 internal/xgjson/testdata/test_multiclass.json create mode 100644 internal/xgjson/testdata/test_multiclass.py create mode 100644 internal/xgjson/testdata/test_multiclass.ubj create mode 100644 internal/xgjson/testdata/test_multiclass_expected.json create mode 100644 internal/xgjson/testdata/test_poisson.json create mode 100644 internal/xgjson/testdata/test_poisson.py create mode 100644 internal/xgjson/testdata/test_poisson.ubj create mode 100644 internal/xgjson/testdata/test_poisson_expected.json create mode 100644 internal/xgjson/testdata/test_regression.json create mode 100644 internal/xgjson/testdata/test_regression.py create mode 100644 internal/xgjson/testdata/test_regression.ubj create mode 100644 internal/xgjson/testdata/test_regression_expected.json create mode 100644 internal/xgjson/testdata/train.py create mode 100644 internal/xgjson/xgjson_io.go create mode 100644 xgensemble_json_io.go create mode 100644 xgensemble_json_test.go diff --git a/internal/ubjdecode/decode.go b/internal/ubjdecode/decode.go new file mode 100644 index 0000000..f9eb870 --- /dev/null +++ b/internal/ubjdecode/decode.go @@ -0,0 +1,536 @@ +// Package ubjdecode implements a standalone Universal Binary JSON (UBJSON) decoder. +// Zero external dependencies — uses only encoding/binary and io from stdlib. +// +// UBJSON spec: https://ubjson.org/ +// XGBoost uses UBJSON for its .ubj model format. +package ubjdecode + +import ( + "encoding/binary" + "fmt" + "io" + "math" +) + +// DecodeValue reads one UBJSON value from r and returns it as a native Go value: +// +// map[string]interface{} for objects +// []interface{} for untyped/mixed arrays +// []int32, []int64, []float32, []float64 for optimized typed arrays +// string, bool, int64, float64, nil for scalars +func DecodeValue(r io.Reader) (interface{}, error) { + marker, err := readByte(r) + if err != nil { + return nil, err + } + return decodeMarker(r, marker) +} + +func decodeMarker(r io.Reader, marker byte) (interface{}, error) { + for marker == 'N' { // noop — skip and read next + var err error + marker, err = readByte(r) + if err != nil { + return nil, err + } + } + + switch marker { + case 'Z': // null + return nil, nil + case 'T': // true + return true, nil + case 'F': // false + return false, nil + case 'i': // int8 + b, err := readByte(r) + if err != nil { + return nil, err + } + return int64(int8(b)), nil + case 'U': // uint8 + b, err := readByte(r) + if err != nil { + return nil, err + } + return int64(b), nil + case 'I': // int16 big-endian + var v int16 + if err := binary.Read(r, binary.BigEndian, &v); err != nil { + return nil, err + } + return int64(v), nil + case 'l': // int32 big-endian + var v int32 + if err := binary.Read(r, binary.BigEndian, &v); err != nil { + return nil, err + } + return int64(v), nil + case 'L': // int64 big-endian + var v int64 + if err := binary.Read(r, binary.BigEndian, &v); err != nil { + return nil, err + } + return v, nil + case 'd': // float32 big-endian + var bits uint32 + if err := binary.Read(r, binary.BigEndian, &bits); err != nil { + return nil, err + } + return float64(math.Float32frombits(bits)), nil + case 'D': // float64 big-endian + var bits uint64 + if err := binary.Read(r, binary.BigEndian, &bits); err != nil { + return nil, err + } + return math.Float64frombits(bits), nil + case 'C': // char — single byte as string + b, err := readByte(r) + if err != nil { + return nil, err + } + return string([]byte{b}), nil + case 'S': // string — length marker + bytes + n, err := readLength(r) + if err != nil { + return nil, fmt.Errorf("string length: %w", err) + } + buf := make([]byte, n) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + return string(buf), nil + case 'H': // high-precision number — decode as string + n, err := readLength(r) + if err != nil { + return nil, fmt.Errorf("high-precision length: %w", err) + } + buf := make([]byte, n) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + return string(buf), nil + case '[': // array + return decodeArray(r) + case '{': // object + return decodeObject(r) + default: + return nil, fmt.Errorf("ubjdecode: unknown marker byte 0x%02X ('%c')", marker, marker) + } +} + +// readByte reads a single byte from r. +func readByte(r io.Reader) (byte, error) { + var buf [1]byte + _, err := io.ReadFull(r, buf[:]) + return buf[0], err +} + +// readLength reads an integer length value (used for string lengths, array/object counts). +// The next byte is a type marker ('i','U','I','l','L') indicating the integer width. +func readLength(r io.Reader) (int64, error) { + marker, err := readByte(r) + if err != nil { + return 0, err + } + switch marker { + case 'i': + b, err := readByte(r) + return int64(int8(b)), err + case 'U': + b, err := readByte(r) + return int64(b), err + case 'I': + var v int16 + err := binary.Read(r, binary.BigEndian, &v) + return int64(v), err + case 'l': + var v int32 + err := binary.Read(r, binary.BigEndian, &v) + return int64(v), err + case 'L': + var v int64 + err := binary.Read(r, binary.BigEndian, &v) + return v, err + default: + return 0, fmt.Errorf("ubjdecode: unexpected length marker 0x%02X ('%c')", marker, marker) + } +} + +// decodeArray decodes a UBJSON array. The '[' marker has already been consumed. +// Handles all four container forms: +// 1. [$][type][#][len] — typed + counted → returns typed Go slice +// 2. [#][len] — counted only → returns []interface{} +// 3. [$][type] + ']' — typed, terminated +// 4. ']' terminator — untyped, terminated → returns []interface{} +func decodeArray(r io.Reader) (interface{}, error) { + next, err := readByte(r) + if err != nil { + return nil, err + } + + // Typed array: [$][type][#][count] + if next == '$' { + typeMarker, err := readByte(r) + if err != nil { + return nil, err + } + // Peek for '#' + after, err := readByte(r) + if err != nil { + return nil, err + } + if after == '#' { + // typed + counted: fast path + count, err := readLength(r) + if err != nil { + return nil, fmt.Errorf("typed array count: %w", err) + } + return readTypedArray(r, typeMarker, int(count)) + } + // typed + unterminated: read until ']' + // `after` is the first element's marker? No — in typed arrays the element + // marker is NOT repeated; `after` here should be ']' or data bytes. + // Actually when typed without count, each element has no marker prefix — + // the type is fixed. So `after` is the first data byte of the first element. + // We must decode accordingly. + return readTypedArrayUntilEnd(r, typeMarker, after) + } + + // Counted only: [#][len] + if next == '#' { + count, err := readLength(r) + if err != nil { + return nil, fmt.Errorf("counted array length: %w", err) + } + result := make([]interface{}, 0, int(count)) + for i := int64(0); i < count; i++ { + v, err := DecodeValue(r) + if err != nil { + return nil, fmt.Errorf("array element %d: %w", i, err) + } + result = append(result, v) + } + return result, nil + } + + // Unterminated untyped: read until ']' + if next == ']' { + return []interface{}{}, nil + } + // next is the marker of the first element + result := []interface{}{} + v, err := decodeMarker(r, next) + if err != nil { + return nil, fmt.Errorf("array element 0: %w", err) + } + result = append(result, v) + for { + marker, err := readByte(r) + if err != nil { + return nil, err + } + if marker == ']' { + break + } + v, err := decodeMarker(r, marker) + if err != nil { + return nil, fmt.Errorf("array element: %w", err) + } + result = append(result, v) + } + return result, nil +} + +// readTypedArray reads `count` elements of a fixed type without per-element markers. +// Returns typed Go slices for performance-critical XGBoost arrays. +func readTypedArray(r io.Reader, typeMarker byte, count int) (interface{}, error) { + switch typeMarker { + case 'i': // int8 → []int32 + buf := make([]int8, count) + for i := range buf { + b, err := readByte(r) + if err != nil { + return nil, err + } + buf[i] = int8(b) + } + out := make([]int32, count) + for i, v := range buf { + out[i] = int32(v) + } + return out, nil + case 'U': // uint8 → []int32 + buf := make([]byte, count) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + out := make([]int32, count) + for i, v := range buf { + out[i] = int32(v) + } + return out, nil + case 'I': // int16 → []int32 + out := make([]int32, count) + for i := range out { + var v int16 + if err := binary.Read(r, binary.BigEndian, &v); err != nil { + return nil, err + } + out[i] = int32(v) + } + return out, nil + case 'l': // int32 → []int32 + out := make([]int32, count) + for i := range out { + if err := binary.Read(r, binary.BigEndian, &out[i]); err != nil { + return nil, err + } + } + return out, nil + case 'L': // int64 → []int64 + out := make([]int64, count) + for i := range out { + if err := binary.Read(r, binary.BigEndian, &out[i]); err != nil { + return nil, err + } + } + return out, nil + case 'd': // float32 → []float32 + out := make([]float32, count) + for i := range out { + var bits uint32 + if err := binary.Read(r, binary.BigEndian, &bits); err != nil { + return nil, err + } + out[i] = math.Float32frombits(bits) + } + return out, nil + case 'D': // float64 → []float64 + out := make([]float64, count) + for i := range out { + var bits uint64 + if err := binary.Read(r, binary.BigEndian, &bits); err != nil { + return nil, err + } + out[i] = math.Float64frombits(bits) + } + return out, nil + default: + // Fall back to []interface{} for other typed arrays + result := make([]interface{}, 0, count) + for i := 0; i < count; i++ { + v, err := decodeMarker(r, typeMarker) + if err != nil { + return nil, fmt.Errorf("typed array element %d: %w", i, err) + } + result = append(result, v) + } + return result, nil + } +} + +// readTypedArrayUntilEnd reads a typed array without a count, terminated by ']'. +// firstDataByte is the first raw data byte already consumed (not a marker). +func readTypedArrayUntilEnd(r io.Reader, typeMarker byte, firstDataByte byte) (interface{}, error) { + // For typed arrays without count, we need to reconstruct the first element + // and then keep reading until we see ']'. + // Strategy: build []interface{} since we don't know the count. + result := []interface{}{} + + // Decode first element from firstDataByte + first, err := decodeSingleFromByte(typeMarker, firstDataByte, r) + if err != nil { + return nil, err + } + result = append(result, first) + + for { + // For typed arrays, next byte is raw data (no marker), or ']' to end + b, err := readByte(r) + if err != nil { + return nil, err + } + if b == ']' { + break + } + v, err := decodeSingleFromByte(typeMarker, b, r) + if err != nil { + return nil, err + } + result = append(result, v) + } + return result, nil +} + +// decodeSingleFromByte decodes one element of a typed array given the first +// byte already consumed and any remaining bytes read from r. +func decodeSingleFromByte(typeMarker byte, firstByte byte, r io.Reader) (interface{}, error) { + switch typeMarker { + case 'i': + return int64(int8(firstByte)), nil + case 'U': + return int64(firstByte), nil + case 'I': + var buf [1]byte + if _, err := io.ReadFull(r, buf[:]); err != nil { + return nil, err + } + v := int16(firstByte)<<8 | int16(buf[0]) + return int64(v), nil + case 'l': + var rest [3]byte + if _, err := io.ReadFull(r, rest[:]); err != nil { + return nil, err + } + v := int32(firstByte)<<24 | int32(rest[0])<<16 | int32(rest[1])<<8 | int32(rest[2]) + return int64(v), nil + case 'L': + var rest [7]byte + if _, err := io.ReadFull(r, rest[:]); err != nil { + return nil, err + } + v := int64(firstByte)<<56 | int64(rest[0])<<48 | int64(rest[1])<<40 | + int64(rest[2])<<32 | int64(rest[3])<<24 | int64(rest[4])<<16 | + int64(rest[5])<<8 | int64(rest[6]) + return v, nil + case 'd': + var rest [3]byte + if _, err := io.ReadFull(r, rest[:]); err != nil { + return nil, err + } + bits := uint32(firstByte)<<24 | uint32(rest[0])<<16 | uint32(rest[1])<<8 | uint32(rest[2]) + return float64(math.Float32frombits(bits)), nil + case 'D': + var rest [7]byte + if _, err := io.ReadFull(r, rest[:]); err != nil { + return nil, err + } + bits := uint64(firstByte)<<56 | uint64(rest[0])<<48 | uint64(rest[1])<<40 | + uint64(rest[2])<<32 | uint64(rest[3])<<24 | uint64(rest[4])<<16 | + uint64(rest[5])<<8 | uint64(rest[6]) + return math.Float64frombits(bits), nil + default: + // treat firstByte as the marker + return decodeMarker(r, firstByte) + } +} + +// decodeObject decodes a UBJSON object. The '{' marker has already been consumed. +// Returns map[string]interface{}. +func decodeObject(r io.Reader) (map[string]interface{}, error) { + result := make(map[string]interface{}) + + // Check for optimized form: [{][#][len] or [{][$][type][#][len] + next, err := readByte(r) + if err != nil { + return nil, err + } + + if next == '#' { + // counted object + count, err := readLength(r) + if err != nil { + return nil, fmt.Errorf("object count: %w", err) + } + for i := int64(0); i < count; i++ { + k, v, err := readKeyValue(r) + if err != nil { + return nil, fmt.Errorf("object key/value %d: %w", i, err) + } + result[k] = v + } + return result, nil + } + + if next == '}' { + return result, nil + } + + // Unoptimized: next byte is the length marker of the first key + // In UBJSON objects, keys are NOT prefixed with 'S' — just length_marker + bytes + k, v, err := readKeyValueFromLenMarker(r, next) + if err != nil { + return nil, fmt.Errorf("object first key/value: %w", err) + } + result[k] = v + + for { + marker, err := readByte(r) + if err != nil { + return nil, err + } + if marker == '}' { + break + } + k, v, err := readKeyValueFromLenMarker(r, marker) + if err != nil { + return nil, fmt.Errorf("object key/value: %w", err) + } + result[k] = v + } + return result, nil +} + +// readKeyValue reads one key-value pair from an object. +// The key length marker has NOT been consumed yet. +func readKeyValue(r io.Reader) (string, interface{}, error) { + lenMarker, err := readByte(r) + if err != nil { + return "", nil, err + } + return readKeyValueFromLenMarker(r, lenMarker) +} + +// readKeyValueFromLenMarker reads one key-value pair where lenMarker is the +// already-consumed first byte of the key length. +func readKeyValueFromLenMarker(r io.Reader, lenMarker byte) (string, interface{}, error) { + // lenMarker is the type marker for the key length + var keyLen int64 + switch lenMarker { + case 'i': + b, err := readByte(r) + if err != nil { + return "", nil, err + } + keyLen = int64(int8(b)) + case 'U': + b, err := readByte(r) + if err != nil { + return "", nil, err + } + keyLen = int64(b) + case 'I': + var v int16 + if err := binary.Read(r, binary.BigEndian, &v); err != nil { + return "", nil, err + } + keyLen = int64(v) + case 'l': + var v int32 + if err := binary.Read(r, binary.BigEndian, &v); err != nil { + return "", nil, err + } + keyLen = int64(v) + case 'L': + var v int64 + if err := binary.Read(r, binary.BigEndian, &v); err != nil { + return "", nil, err + } + keyLen = v + default: + return "", nil, fmt.Errorf("ubjdecode: unexpected key length marker 0x%02X ('%c')", lenMarker, lenMarker) + } + + keyBuf := make([]byte, keyLen) + if _, err := io.ReadFull(r, keyBuf); err != nil { + return "", nil, fmt.Errorf("reading key: %w", err) + } + key := string(keyBuf) + + val, err := DecodeValue(r) + if err != nil { + return "", nil, fmt.Errorf("value for key %q: %w", key, err) + } + return key, val, nil +} diff --git a/internal/xgjson/detect.go b/internal/xgjson/detect.go new file mode 100644 index 0000000..3250ded --- /dev/null +++ b/internal/xgjson/detect.go @@ -0,0 +1,121 @@ +package xgjson + +import ( + "bufio" + "encoding/binary" + "bytes" +) + +// LooksLikeJSON peeks at the buffered reader and returns true if the content +// appears to be an XGBoost JSON model file. It does not advance the reader. +// +// Heuristic: the first non-whitespace byte must be '{', and the string +// "\"learner\"" (the mandatory top-level key in every XGBoost model) must +// appear within the first 512 peeked bytes. +func LooksLikeJSON(r *bufio.Reader) (bool, error) { + buf, err := r.Peek(512) + if len(buf) == 0 { + return false, err + } + + // Find first non-whitespace byte; it must be '{' + firstNonWS := -1 + for i, b := range buf { + if b != ' ' && b != '\t' && b != '\r' && b != '\n' { + firstNonWS = i + break + } + } + if firstNonWS < 0 || buf[firstNonWS] != '{' { + return false, nil + } + + // Confirm the mandatory top-level key is present in the peeked window. + // Every XGBoost JSON model has "learner" as its first object key. + return bytes.Contains(buf, []byte(`"learner"`)), nil +} + +// LooksLikeUBJ peeks at the buffered reader and returns true if the content +// appears to be an XGBoost UBJ (Universal Binary JSON) model file. It does +// not advance the reader. +// +// Heuristic: UBJ objects start with '{' (0x7B), and the first key in an +// XGBoost UBJ model is always "learner". We verify the first byte is '{' and +// then parse the UBJ key-length encoding + key bytes from the peek buffer, +// checking that the first key is exactly "learner". +// +// UBJ key lengths are encoded as: , where the +// marker is one of: 'i' (int8), 'U' (uint8), 'I' (int16 BE), 'l' (int32 BE), +// 'L' (int64 BE). XGBoost currently uses 'L', so we handle all widths for +// robustness. +func LooksLikeUBJ(r *bufio.Reader) (bool, error) { + // 32 bytes is more than enough: 1 ('{') + 1 (marker) + 8 (int64) + 7 ("learner") = 17 + buf, err := r.Peek(32) + if len(buf) < 2 { + return false, err + } + + if buf[0] != '{' { + return false, nil + } + + // Parse the key length from buf[1:] + keyLen, headerLen, ok := ubjReadLength(buf[1:]) + if !ok { + return false, nil + } + + // "learner" is 7 bytes + if keyLen != 7 { + return false, nil + } + + start := 1 + headerLen + end := start + 7 + if end > len(buf) { + return false, nil + } + + return bytes.Equal(buf[start:end], []byte("learner")), nil +} + +// ubjReadLength parses a UBJ integer-length from b (starting at b[0] which is +// the type marker). Returns (value, bytesConsumed, ok). +func ubjReadLength(b []byte) (value int, consumed int, ok bool) { + if len(b) < 1 { + return 0, 0, false + } + marker := b[0] + switch marker { + case 'i': // int8 — 1 data byte + if len(b) < 2 { + return 0, 0, false + } + return int(int8(b[1])), 2, true + case 'U': // uint8 — 1 data byte + if len(b) < 2 { + return 0, 0, false + } + return int(b[1]), 2, true + case 'I': // int16 BE — 2 data bytes + if len(b) < 3 { + return 0, 0, false + } + v := int16(binary.BigEndian.Uint16(b[1:3])) + return int(v), 3, true + case 'l': // int32 BE — 4 data bytes + if len(b) < 5 { + return 0, 0, false + } + v := int32(binary.BigEndian.Uint32(b[1:5])) + return int(v), 5, true + case 'L': // int64 BE — 8 data bytes + if len(b) < 9 { + return 0, 0, false + } + v := binary.BigEndian.Uint64(b[1:9]) + return int(v), 9, true + default: + return 0, 0, false + } +} diff --git a/internal/xgjson/detect_test.go b/internal/xgjson/detect_test.go new file mode 100644 index 0000000..d5ec3ab --- /dev/null +++ b/internal/xgjson/detect_test.go @@ -0,0 +1,110 @@ +package xgjson_test + +import ( + "bufio" + "bytes" + "os" + "testing" + + "github.com/dmitryikh/leaves/internal/xgjson" +) + +const testdataDir = "testdata/" + +func TestLooksLikeJSON(t *testing.T) { + t.Run("real JSON model", func(t *testing.T) { + r := openBuf(t, testdataDir+"test_binary_logistic.json") + if ok, err := xgjson.LooksLikeJSON(r); err != nil || !ok { + t.Errorf("LooksLikeJSON(json) = (%v, %v), want (true, nil)", ok, err) + } + }) + + t.Run("UBJ file is not JSON", func(t *testing.T) { + r := openBuf(t, testdataDir+"test_binary_logistic.ubj") + if ok, _ := xgjson.LooksLikeJSON(r); ok { + t.Error("LooksLikeJSON(ubj) = true, want false") + } + }) + + t.Run("empty reader", func(t *testing.T) { + r := bufio.NewReader(bytes.NewReader(nil)) + if ok, _ := xgjson.LooksLikeJSON(r); ok { + t.Error("LooksLikeJSON(empty) = true, want false") + } + }) + + t.Run("random bytes", func(t *testing.T) { + r := bufio.NewReader(bytes.NewReader([]byte{0x00, 0x01, 0x02, 0x03, 0xFF})) + if ok, _ := xgjson.LooksLikeJSON(r); ok { + t.Error("LooksLikeJSON(random bytes) = true, want false") + } + }) + + t.Run("reader is not consumed", func(t *testing.T) { + f, err := os.Open(testdataDir + "test_binary_logistic.json") + if err != nil { + t.Fatal(err) + } + defer f.Close() + r := bufio.NewReader(f) + xgjson.LooksLikeJSON(r) //nolint:errcheck + // Peek must not consume: a second call should return the same result. + if ok, _ := xgjson.LooksLikeJSON(r); !ok { + t.Error("second LooksLikeJSON call returned false — reader was consumed") + } + }) +} + +func TestLooksLikeUBJ(t *testing.T) { + t.Run("real UBJ model", func(t *testing.T) { + r := openBuf(t, testdataDir+"test_binary_logistic.ubj") + if ok, err := xgjson.LooksLikeUBJ(r); err != nil || !ok { + t.Errorf("LooksLikeUBJ(ubj) = (%v, %v), want (true, nil)", ok, err) + } + }) + + t.Run("JSON file is not UBJ", func(t *testing.T) { + r := openBuf(t, testdataDir+"test_binary_logistic.json") + if ok, _ := xgjson.LooksLikeUBJ(r); ok { + t.Error("LooksLikeUBJ(json) = true, want false") + } + }) + + t.Run("empty reader", func(t *testing.T) { + r := bufio.NewReader(bytes.NewReader(nil)) + if ok, _ := xgjson.LooksLikeUBJ(r); ok { + t.Error("LooksLikeUBJ(empty) = true, want false") + } + }) + + t.Run("random bytes", func(t *testing.T) { + r := bufio.NewReader(bytes.NewReader([]byte{0x00, 0x01, 0x02, 0x03, 0xFF})) + if ok, _ := xgjson.LooksLikeUBJ(r); ok { + t.Error("LooksLikeUBJ(random bytes) = true, want false") + } + }) + + t.Run("reader is not consumed", func(t *testing.T) { + f, err := os.Open(testdataDir + "test_binary_logistic.ubj") + if err != nil { + t.Fatal(err) + } + defer f.Close() + r := bufio.NewReader(f) + xgjson.LooksLikeUBJ(r) //nolint:errcheck + // Peek must not consume: a second call should return the same result. + if ok, _ := xgjson.LooksLikeUBJ(r); !ok { + t.Error("second LooksLikeUBJ call returned false — reader was consumed") + } + }) +} + +func openBuf(t *testing.T, path string) *bufio.Reader { + t.Helper() + f, err := os.Open(path) + if err != nil { + t.Fatalf("open %s: %v", path, err) + } + t.Cleanup(func() { f.Close() }) + return bufio.NewReader(f) +} diff --git a/internal/xgjson/testdata/README.md b/internal/xgjson/testdata/README.md new file mode 100644 index 0000000..105be25 --- /dev/null +++ b/internal/xgjson/testdata/README.md @@ -0,0 +1,22 @@ +# XGBoost Test Model Data + +This directory contains generated XGBoost model files used by the `xgjson` package tests. + +## Files + +- `train.py` — Python script that generates the test models +- `test1.json` — XGBoost model in JSON format (XGBoost 2.x/3.x) +- `test1.ubj` — XGBoost model in Universal Binary JSON format (XGBoost 2.x/3.x) +- `test1_expected.json` — Input feature vectors with expected raw scores and probabilities + +## Regenerating + +Requires [uv](https://docs.astral.sh/uv/). + +```sh +cd internal/xgjson/testdata +go generate +``` + +This trains a small XGBoost binary classifier (10 estimators, max_depth=3) on synthetic data +with 3 features, then writes the three output files above. Commit all three output files. diff --git a/internal/xgjson/testdata/generate.go b/internal/xgjson/testdata/generate.go new file mode 100644 index 0000000..73179f5 --- /dev/null +++ b/internal/xgjson/testdata/generate.go @@ -0,0 +1,3 @@ +//go:generate uv run train.py + +package testdata diff --git a/internal/xgjson/testdata/test_binary_logistic.json b/internal/xgjson/testdata/test_binary_logistic.json new file mode 100644 index 0000000..8d20caf --- /dev/null +++ b/internal/xgjson/testdata/test_binary_logistic.json @@ -0,0 +1 @@ +{"learner":{"attributes":{},"feature_names":[],"feature_types":[],"gradient_booster":{"model":{"cats":{"enc":[],"feature_segments":[],"sorted_idx":[]},"gbtree_model_param":{"num_parallel_tree":"1","num_trees":"10"},"iteration_indptr":[0,1,2,3,4,5,6,7,8,9,10],"tree_info":[0,0,0,0,0,0,0,0,0,0],"trees":[{"base_weights":[3.4848927E-8,-1.5204959E0,1.6623036E0,1.1777768E0,-1.6315176E0,7.10179E-1,1.8392156E0,4.33977E-1,7.083949E-2,-5.601467E-1,-1.6758814E-1,3.6204177E-1,-1.5752013E-1,4.071052E-1,5.7988805E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":0,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[5.1048126E2,3.3520996E1,1.4945984E1,8.6061764E-1,2.4676788E1,1.0907537E1,1.0265808E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[2.733439E-1,-1.8622892E0,5.279944E-1,1.3041685E0,-2.5623593E-1,1.3718903E0,7.7473426E-1,4.33977E-1,7.083949E-2,-5.601467E-1,-1.6758814E-1,3.6204177E-1,-1.5752013E-1,4.071052E-1,5.7988805E-1],"split_indices":[2,0,2,1,2,1,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.9996877E2,1.0448368E2,9.5485085E1,3.7494144E0,1.0073427E2,1.574754E1,7.973755E1,2.4996095E0,1.2498047E0,8.19872E1,1.8747072E1,1.1248243E1,4.499297E0,1.549758E1,6.423997E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[3.3946685E-3,-1.1543779E0,1.2591227E0,9.5409256E-1,-1.2454361E0,5.2215135E-1,1.4025306E0,3.727428E-1,2.3597144E-2,-4.3348977E-1,-1.21060215E-1,4.5374906E-1,2.7029645E-2,3.0377525E-1,4.4462258E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":1,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[2.7547772E2,2.011644E1,8.874084E0,1.0756211E0,1.5365356E1,6.8806324E0,8.410034E-1,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[2.733439E-1,-1.8622892E0,5.279944E-1,-7.9844646E-2,-2.5623593E-1,-1.8622892E0,7.7473426E-1,3.727428E-1,2.3597144E-2,-4.3348977E-1,-1.21060215E-1,4.5374906E-1,2.7029645E-2,3.0377525E-1,4.4462258E-1],"split_indices":[2,0,2,2,2,0,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.874821E2,9.758766E1,8.989444E1,3.647652E0,9.394001E1,1.5399192E1,7.449525E1,2.428422E0,1.2192299E0,7.53628E1,1.8577206E1,3.9179015E0,1.1481291E1,1.4947725E1,5.954752E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[7.773815E-3,-9.943017E-1,9.9700433E-1,2.7761754E-1,-1.076667E0,5.286407E-1,3.7626263E-1,-3.394221E-1,3.191901E-1,2.0238668E-1,-3.8180283E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0],"id":2,"left_children":[1,3,5,-1,7,9,-1,-1,-1,-1,-1],"loss_changes":[1.6563177E2,1.4121422E1,9.465561E0,0E0,1.0584633E1,9.1505995E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5],"right_children":[2,4,6,-1,8,10,-1,-1,-1,-1,-1],"split_conditions":[2.0596176E-1,-1.8622892E0,7.2267E-1,2.7761754E-1,-4.3772988E-2,2.3320758E0,3.7626263E-1,-3.394221E-1,3.191901E-1,2.0238668E-1,-3.8180283E-1],"split_indices":[2,0,2,0,0,1,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.6508813E2,8.200907E1,8.307906E1,2.9528005E0,7.905627E1,3.0533514E1,5.254555E1,7.754477E1,1.5115029E0,2.8823126E1,1.7103883E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"11","size_leaf_vector":"1"}},{"base_weights":[9.2651285E-3,-1.0554368E0,6.792614E-1,2.3023954E-1,-1.1098394E0,1.0147718E-1,1.0298864E0,-5.1912617E-2,-3.520825E-1,1.7861223E-1,-3.123424E-1,1.9472034E-1,3.3861062E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":3,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[1.0307969E2,6.3869324E0,1.786119E1,0E0,3.0589828E0,2.0146532E1,1.3401375E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-2.5623593E-1,-2.0155988E0,5.279944E-1,2.3023954E-1,-2.6036727E-1,1.3718903E0,7.7473426E-1,-5.1912617E-2,-3.520825E-1,1.7861223E-1,-3.123424E-1,1.9472034E-1,3.3861062E-1],"split_indices":[2,0,2,0,1,1,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.4249661E2,5.4814583E1,8.768203E1,1.2086432E0,5.3605938E1,3.360115E1,5.4080883E1,3.6531324E0,4.9952805E1,2.379479E1,9.806361E0,1.2701178E1,4.1379704E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[6.5199593E-3,-9.5052385E-1,5.596426E-1,2.0603801E-1,-1.0059912E0,7.9649486E-2,9.219858E-1,-3.384241E-1,-1.2668663E-1,1.2908778E-1,-2.4213417E-1,1.5698004E-1,3.0862683E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":4,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[6.568722E1,4.8187065E0,1.3602236E1,0E0,2.6854439E0,1.1204861E1,1.3651428E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-2.5623593E-1,-2.0155988E0,5.594684E-1,2.0603801E-1,-5.7463634E-1,1.3718903E0,7.7473426E-1,-3.384241E-1,-1.2668663E-1,1.2908778E-1,-2.4213417E-1,1.5698004E-1,3.0862683E-1],"split_indices":[2,0,2,0,2,1,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2208248E2,4.445151E1,7.763097E1,1.0835884E0,4.3367924E1,3.391881E1,4.371216E1,3.5094765E1,8.273158E0,2.4676153E1,9.242659E0,1.0479339E1,3.323282E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[3.5492254E-3,-5.9674454E-1,6.862838E-1,-1.0244098E0,-1.5051265E-1,3.5410693E-1,9.457186E-1,-1.983951E-2,-3.1750667E-1,-1.2765785E-1,4.3486968E-1,3.8029623E-1,1.0213314E-2,2.9135826E-1,1.1577517E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":5,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[4.385949E1,1.0713873E1,3.979826E0,8.703499E-1,1.3396497E1,6.8439E0,1.7572403E-2,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[3.263706E-1,-5.7463634E-1,7.7473426E-1,-1.534747E0,-7.9241884E-1,-1.8355771E0,1.0084084E0,-1.983951E-2,-3.1750667E-1,-1.2765785E-1,4.3486968E-1,3.8029623E-1,1.0213314E-2,2.9135826E-1,1.1577517E-1],"split_indices":[2,2,2,0,0,0,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.0501652E2,5.5949192E1,4.9067326E1,2.7907331E1,2.804186E1,2.255417E1,2.6513155E1,1.0556241E0,2.6851707E1,2.455545E1,3.4864097E0,5.08358E0,1.747059E1,2.4656202E1,1.8569524E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[9.643211E-3,-9.6156096E-1,3.2266313E-1,-1.6817693E-2,-1.0025502E0,1.896752E-2,3.062792E-1,-3.163531E-2,-3.1356356E-1,3.1077072E-1,-8.825563E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0],"id":6,"left_children":[1,3,5,-1,7,9,-1,-1,-1,-1,-1],"loss_changes":[2.8589937E1,7.9467964E-1,1.5057805E1,0E0,7.268925E-1,1.6408224E1,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5],"right_children":[2,4,6,-1,8,10,-1,-1,-1,-1,-1],"split_conditions":[-5.7463634E-1,-1.534747E0,-7.9241884E-1,-1.6817693E-2,-2.3389478E0,-1.8355771E0,3.062792E-1,-3.163531E-2,-3.1356356E-1,3.1077072E-1,-8.825563E-2],"split_indices":[2,0,0,0,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[9.202333E1,2.1924986E1,7.009835E1,1.052031E0,2.0872955E1,4.9567326E1,2.0531021E1,1.1051567E0,1.97678E1,1.1126637E1,3.844069E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"11","size_leaf_vector":"1"}},{"base_weights":[1.0875948E-2,-9.042081E-1,2.6562464E-1,-1.4255263E-2,-9.5357305E-1,2.432597E-2,9.072643E-1,-2.707678E-2,-3.0165362E-1,1.6966063E-1,-1.2472395E-1,2.841285E-1,9.571609E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":7,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[1.9578981E1,7.282772E-1,1.0144684E1,0E0,6.824322E-1,1.1845299E1,1.1762619E-1,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-5.7463634E-1,-1.534747E0,-7.9241884E-1,-1.4255263E-2,-2.3389478E0,-1.451989E0,1.0084084E0,-2.707678E-2,-3.0165362E-1,1.6966063E-1,-1.2472395E-1,2.841285E-1,9.571609E-2],"split_indices":[2,0,0,0,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[8.195767E1,1.7292343E1,6.466533E1,1.0490721E0,1.6243273E1,4.7747166E1,1.6918161E1,1.0771346E0,1.5166138E1,2.1285168E1,2.6462E1,1.5273769E1,1.6443915E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[8.420152E-3,-8.866207E-1,2.0744449E-1,-2.3173474E-2,-2.8409058E-1,3.5733518E-1,-6.553415E-1,3.0889487E-1,2.7870974E-2,2.424394E-1,-3.562318E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0],"id":8,"left_children":[1,3,5,-1,-1,7,9,-1,-1,-1,-1],"loss_changes":[1.3613471E1,6.0597706E-1,8.315063E0,0E0,0E0,9.483938E0,8.436519E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,5,5,6,6],"right_children":[2,4,6,-1,-1,8,10,-1,-1,-1,-1],"split_conditions":[-6.423459E-1,-2.3389478E0,1.8273275E0,-2.3173474E-2,-2.8409058E-1,-1.4220484E0,-2.035454E0,3.0889487E-1,2.7870974E-2,2.424394E-1,-3.562318E-1],"split_indices":[2,2,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[7.438975E1,1.2903961E1,6.148579E1,1.053575E0,1.1850386E1,5.2884132E1,8.601657E0,1.4111651E1,3.8772484E1,2.1552174E0,6.4464393E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"11","size_leaf_vector":"1"}},{"base_weights":[3.1257947E-3,-2.370009E-1,7.368995E-1,-4.209105E-5,-8.483909E-1,8.1515753E-1,1.465431E-1,3.7057182E-1,-1.14696495E-1,2.5674397E-1,-3.5496542E-1,7.9826936E-2,2.756763E-1,-1.3072331E-1,1.8127826E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":9,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.2430462E1,7.701972E0,6.841221E-1,1.9037214E1,9.485901E0,6.1796856E-1,1.1308331E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[7.2267E-1,1.398677E0,1.8586534E0,-1.2937937E0,-2.0780978E0,6.0014886E-1,1.2276509E0,3.7057182E-1,-1.14696495E-1,2.5674397E-1,-3.5496542E-1,7.9826936E-2,2.756763E-1,-1.3072331E-1,1.8127826E-1],"split_indices":[2,1,1,0,0,1,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[6.855681E1,5.2163403E1,1.6393408E1,3.8313972E1,1.3849428E1,1.4138422E1,2.254987E0,8.527461E0,2.9786512E1,2.0186934E0,1.1830735E1,2.8137515E0,1.132467E1,1.0135728E0,1.2414142E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}}]},"name":"gbtree"},"learner_model_param":{"base_score":"[4.9375E-1]","boost_from_average":"1","num_class":"0","num_feature":"3","num_target":"1"},"objective":{"name":"binary:logistic","reg_loss_param":{"scale_pos_weight":"1"}}},"version":[3,2,0]} \ No newline at end of file diff --git a/internal/xgjson/testdata/test_binary_logistic.py b/internal/xgjson/testdata/test_binary_logistic.py new file mode 100644 index 0000000..dad9bd4 --- /dev/null +++ b/internal/xgjson/testdata/test_binary_logistic.py @@ -0,0 +1,74 @@ +# /// script +# dependencies = ["xgboost>=2.0", "numpy", "scikit-learn"] +# /// +""" +Generates binary:logistic XGBoost test model. +Run with: uv run test_binary_logistic.py +Outputs: test_binary_logistic.json, test_binary_logistic.ubj, test_binary_logistic_expected.json +""" + +import json +import os +import numpy as np +from sklearn.datasets import make_classification +import xgboost as xgb + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def main(): + X, y = make_classification( + n_samples=1000, + n_features=3, + n_informative=3, + n_redundant=0, + n_clusters_per_class=1, + random_state=42, + ) + + X_train, X_test = X[:800], X[800:] + y_train = y[:800] + + model = xgb.XGBClassifier( + n_estimators=10, + max_depth=3, + objective="binary:logistic", + random_state=42, + eval_metric="logloss", + ) + model.fit(X_train, y_train) + + json_path = os.path.join(SCRIPT_DIR, "test_binary_logistic.json") + ubj_path = os.path.join(SCRIPT_DIR, "test_binary_logistic.ubj") + model.get_booster().save_model(json_path) + model.get_booster().save_model(ubj_path) + print(f"Saved {json_path}") + print(f"Saved {ubj_path}") + + test_rows = X_test[:5] + dtest = xgb.DMatrix(test_rows) + raw_preds = model.get_booster().predict(dtest, output_margin=True) + prob_preds = model.predict_proba(test_rows)[:, 1] + + expected = [] + for i in range(len(test_rows)): + expected.append( + { + "features": test_rows[i].tolist(), + "raw_score": float(raw_preds[i]), + "probability": float(prob_preds[i]), + } + ) + + expected_path = os.path.join(SCRIPT_DIR, "test_binary_logistic_expected.json") + with open(expected_path, "w") as f: + json.dump(expected, f, indent=2) + print(f"Saved {expected_path}") + + print(f"\nXGBoost version: {xgb.__version__}") + for e in expected: + print(f" features={[round(v,4) for v in e['features']]} raw={e['raw_score']:.6f} prob={e['probability']:.6f}") + + +if __name__ == "__main__": + main() diff --git a/internal/xgjson/testdata/test_binary_logistic.ubj b/internal/xgjson/testdata/test_binary_logistic.ubj new file mode 100644 index 0000000000000000000000000000000000000000..2123e1447fb1018537e71b6985abe28c71ba9cd0 GIT binary patch literal 11883 zcmd^_4^$M@8o&Wz7g%t4nVOpa@oAs8H7%qq02sAmJ9GjBqWcZw= zAvMY;hFWhHDv-BxPR49x&DpKWWwq|Iq&PW-F}fu%I!v4?L+<$xw#*FoEvuEW8u4Ru ziL%}iCW`mq{T%lfkc zY@pn*TQ1g?W^btyC{?-Xk5uq;OX2abtVmFbiD!(97}k=G$H^>@lRmQduD8WQTQ-At z4<~^atIcBld+Wi3V>7@v3l~zS#~Y&8-91y(GlzxyH&3QsPg*Sujs2dgciaQ+=GTyZ z1EuL=EOw5SW?g9xLR!j=Hff@jpX%lDRLhjc%A(44=Gh$2F2T(yJQZYS(oC6Fr_su! zITJ(Y$?fQGSUouqEqqurEB6tMYZ+71S*w|2Y@S-->fTcBA!R=)Yo#0@Wu26Rn>$FP zuOO?P=cVRNHVYQtV?L%dM;9VZ?_6jfRs|lLQ5~y*v(Jr;)xp{KM#ie)tihQy+>^rwq&h6PEnFYh@t1O-l!K()vx8A-Chdpt z4l66QZ?~CQ{Km8%8fwjdE5t_QlR!Iq8`$3Xi;x$Q2=)~A2h-wz>a3{-U++PB6WP?( zc8KQ{D}WO|lEEQIcY)eJqCnvF<}m;}V{$rPLPt8o^Q=j-F;6R){-Zal|Dg;CKPJ)k z%jxK?MVDaC_ZQHX#l`gCD}`Xol|oc{g@Z3#^@fo@9;2VTO2I{!`=HzeZ=Pv?EodCzXeXfid&xGf#1#ppnSz(O88)ep?cO);nXFg>xagL)K_zk z8ct)${xzFrMC6b+)62xsBM@?>bg?=xfz|EDZVH$#yw>X(Y8 z;*+Zjz_O^_RKf5wVt)N3pp9>)s>+rLyWgy&O08$8zfex0CZ|Qo7XQW5Uyxq~Pq*dw z_kysyLK2xK>;d_Z3w7+^s`#tu_JHuGT{?GrK}w%_(wwTtqnk3Kq!OAgCcwVjcUU5L z-=jO=UFi+@ylo<_DPcvo?7AqV4i=GoyzyJp%CjL2wd2uXHAy`UR!zEgdo}pHZ6fWp z>%r)#xG%T4F)_*OOKeuI$eFX~-gHHKW~Gh^BA_H!D=c88Z%RZ!6)Y&3gPM0D%~+qc zXq#^t+CD0WHeB6})aN{?&b|Q8zEA@`YLvfXfn9?Dc~y3+-bk>a z@6lRuTUX{sR|aHIf%uhryPD9 zmW=5KcIFr3Yo$b5Rf0qCGO2HmO6|fydakp}9<071B&7R6MBTPWB51p?2*y4|l~?n6 zyxhSCC_FvuBw6lsWPs|>AOhpDAijj>VBq^VB_gm2Jd_Pc4IjZ%dM+D=645ZUCn*vX z?t2H#x8@*?WdllxJ3+HY&!g$bcR?Ys0^WPJ0nU6#gGM~{A6bH~!HB%%xn+z*y;qhZ z^7B>!eT<8WyH+CR*2YDb4xS3E+9|^BspkdnH$o}jaW4vsF8NVI#K);3-}$n#cL+v& zi9%hEkvwAqT{!*&xM1rcY}NGrr)R^+>Gy(HT{jsyxj}@z(=ei+6z4X;UIUT|MmjRU zUxyKl6O0mq6*8dmk&Nu^!3b&^9fcx}EkK*5^aQKRHE3*hHPY@ng~C3hQIIJR`X)EP zy7U^@>(fR!a;FjbHm1uGbTf?TqtXipi}%#hVBFqVF>mlmq5t5Ysr`Rm4!Hg=Rc>98 zEo@zD5PS~0sM9Y;Q-=OlF8w`tKh#~^sW zRYO0yKN}<;r{VmBV7kFO%J5v1Z0>3&M0h2@+zH#XFZT};F=V8@5V4lqD9&iOES_tM z0Ly!=?5wG(gNJO>V#uh1rlXNp0g#%>sd+Mb^p1ZfGXDDtiix=hKSmSbC-W-cKHEk3 zZvTs5<)GK#`}MJ~_OEMcWDs0n zv#RR-S{%kkduFAM2|Pk=@zD1Arjb*@g6C+HK~_V84jt58f_u`8=$Y|}V8xCmxcyo^ znkM048(0ON%{&Z?LY;8DWS% zfpoOE@>Km0p^!_U##TJ#YPL_M)-9e9U8^}NE8DG5puDujFwhc()M0=tE5bm!e+zpf zaC1pDnCy5Q4x3$!Z5V`=e*&XdBt(CA7GEkRBCHY|y5ZhPD_BSd=!V7XA02V~7#$d( zIy5}@M#w0tLW4X&MpgwL$_Au{k6=`KE*`F)e-JH?-wswxZ-&b+O+wVq{ou+32IPPI z4CHp?!dL#JfhYbo7CKFjfp>DOvIO01QKK^iH0>5st~^i0TjMKT@5c)3oN?5Frc!Za zZ5CBDGTZR^oSjrQ3J?yMEn^#iZcrvt;opUjy$JaE-k@W38}bWib@7Sdchp2ULs zBVfSL+0ZXAmzXZ>D?A7_*RGLkawM%O!J;Ei;@8HFkPo#xZUn1{lO%}ZFD*o{b~jFS zVyNRy3k%_SoK;}^$3%~Kc2ApE=Zlxsf(7rsgq%QpE}`8F4Tex-@vrQxr1)>PFoL-sW=VmvW}&u`TSQ8PI; z^ypFkz&>d9%X4Agh;q~q{v9T5L@?#St8~bm{m|c$0~==?p}ogXgV8fyqaV0?Ih+vj zAy{5n4T{dTghO-ru;ESz;sv@xj!m3}amuPl zL3Tud0DH<4QUo5G#mG^2t<~12ur?F7c(T=QVRm&zn%X*SFoaJvSsA0l g&fkb3_Uv`3+^<$rmY2ddTOyEqNU_a6tk=^20iK=2.0", "numpy", "scikit-learn"] +# /// +""" +Generates multi:softprob XGBoost test model (3-class). +Run with: uv run test_multiclass.py +Outputs: test_multiclass.json, test_multiclass.ubj, test_multiclass_expected.json + +NOTE on base_score: XGBoost 3.x stores a per-class base_score vector for multi-class models. +The leaves Go library (xgEnsemble) only supports a single scalar BaseScore, so it uses 0.0 +for multi-class. Expected values here are the raw tree sums (output_margin minus per-class +base_score) and the softmax of those sums, which is what the Go library actually computes. +""" + +import json +import os +import numpy as np +from sklearn.datasets import make_classification +import xgboost as xgb + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def softmax(x): + e = np.exp(x - x.max(axis=1, keepdims=True)) + return e / e.sum(axis=1, keepdims=True) + + +def main(): + X, y = make_classification( + n_samples=600, + n_features=4, + n_classes=3, + n_informative=4, + n_redundant=0, + random_state=42, + ) + + X_train, X_test = X[:500], X[500:] + y_train = y[:500] + + model = xgb.XGBClassifier( + n_estimators=10, + max_depth=3, + objective="multi:softprob", + num_class=3, + random_state=42, + ) + model.fit(X_train, y_train) + + json_path = os.path.join(SCRIPT_DIR, "test_multiclass.json") + ubj_path = os.path.join(SCRIPT_DIR, "test_multiclass.ubj") + model.get_booster().save_model(json_path) + model.get_booster().save_model(ubj_path) + print(f"Saved {json_path}") + print(f"Saved {ubj_path}") + + # Parse per-class base_scores from the saved model. + # XGBoost 3.x stores these as "[v0,v1,...,vN]" in learner_model_param.base_score. + with open(json_path) as f: + saved = json.load(f) + bs_str = saved["learner"]["learner_model_param"]["base_score"].strip("[]") + base_scores = np.array([float(x) for x in bs_str.split(",")]) + print(f"Per-class base_scores: {base_scores.tolist()}") + + test_rows = X_test[:5] + dtest = xgb.DMatrix(test_rows) + # output_margin=True gives per-class raw scores including per-class base_score + raw_preds = model.get_booster().predict(dtest, output_margin=True) # shape: [n, 3] + + # Go uses BaseScore=0 for multi-class (xgEnsemble doesn't support per-class base_score). + # Subtract per-class base_scores to get the pure tree sums that Go computes. + go_raw_preds = raw_preds - base_scores # broadcast over rows + go_probs = softmax(go_raw_preds) + + expected = [] + for i in range(len(test_rows)): + expected.append( + { + "features": test_rows[i].tolist(), + "raw_scores": go_raw_preds[i].tolist(), + "probabilities": go_probs[i].tolist(), + } + ) + + expected_path = os.path.join(SCRIPT_DIR, "test_multiclass_expected.json") + with open(expected_path, "w") as f: + json.dump(expected, f, indent=2) + print(f"Saved {expected_path}") + + print(f"\nXGBoost version: {xgb.__version__}") + for e in expected: + print(f" raw={[round(v,4) for v in e['raw_scores']]} prob={[round(v,4) for v in e['probabilities']]}") + + +if __name__ == "__main__": + main() diff --git a/internal/xgjson/testdata/test_multiclass.ubj b/internal/xgjson/testdata/test_multiclass.ubj new file mode 100644 index 0000000000000000000000000000000000000000..4a2c619f592a1fe589cf42c749851dc7b1628675 GIT binary patch literal 34738 zcmd^|2V4}#_rPh2l%t9Qme^v81rQ6u?v%3^?5IJ-Lk>Kp$$=t*6{Q$^uwX+pYV1+5 z!R}$hg0Xj_VnMNEHzv{l&Ba*{y#;;|N#MVq&$2tax3llPdHcQhX6BZ0-Sqxmj05Ge zuwZ%EIC{WT79Jj^@QnzUE62sqgA#waEIcAi?h`Bvk}JLI(&Kt`NBBVa*ifAxunGv1 z`6=YV;Xb}0A7EJs z5ILt5P8zgu66&{7s1=okD?);O6v2L>;b9+DF;nO%3>6g>CJJ+frNTyGr?6L4Q`8hH z*2tv@_7C}ZNlo=7`U-=8_Lo*%K_jhJrGmodzA~lUXN+7CFr3)aPq3$?MdqOfU}5VB zK0`T~ujRcEtgPNMWA&JCRFx+!VJ~=EFsfrGc_X7Y%sBH9zO(!XE@8<&dOuTjbIJok z!W3$+(A7p(AJ6*)s~ti2r);%8^|8sj2_g!9s9l{>5k;!g@hNjZxxXwTFx)3l?jP=5 zub)tlcAsjgNQ96*HDvOYezLJw7Zn~%D>gsW@PgL`HSg;^$JVYBSnq^xV9wHm8^ z`qRflGC?he)LAxvB&w>nmDH=LdNn6y zT??((N*NlcP}dt0?57~lOFKQBm)y_#Z4!CE+4AqoTH$l}kr;2hw$=@}r zpFhR0g5E{ii?kJ4=$0rzjiEX5=^rYv>EG27+!gIWW=b_ChRc;og-q=+T5T)!*_(iS zg-n%xCtnLKU7O=h(HKwOFcP1wnTwkT?uLh>G1`%ofg4NS;Ekzc(W0kY1qn2WC)#W`<*Og=?b^U4Ok-qc#m^tDykBL<1QH9F0O z)vmy18E@hTaqqdckpLh?|f&fLI7Q}%P5j~ zi93aUk1Y&eqeLHLWH9O-Jbx?{PLN!L?Kf@tMDWjww02#9v7DfmL+gr90YjcOvi?)h z5VHuNQSS?oQSV!5n?Tn-6}(X3Sn|AADu0|Apqi8e_}g2zfQQ#Pe*1wjVCE>V;<~Ch zc&Gwuc{K|)2;I`-I=55-5dev~oEU;ErQW*-;=H^b_;}(lym)pZcHEo@pICOmszt#l zq1r+mHm?)5T{9cxyzd9imnXrss)=~sn%QXcp(sI?(m_#6dcQA4k)(y$;QFdY4~MbU z>t0i}R83TAP#}sx{2*X?7^kd`R{Z4OQwh$y9(HKyL z!~s+wqEUOJF?piY+b@i0&<^R3&@px>I_38>XufD7@pH7{bTzb0`IzkMTBxwPfKyaN zL!YZJifGhARr;WQNklZEQFS5CisjoLp9&5it416U9ID!jSGjE{$^pf}q8-uDNQM}b z7(jHPmK6btZU+*>iA3t-xB}O69U^TdNx(HPEd}cbnBbcmM&pFwCD>v7RZTg4co($P zY8&c!Ck+q1HC>RPbTCqK5lwofw>yAq?q)vo)){{KFH1nTNE7by`4pAjxeV5$-wnR) zU%u?LwZLQ7jz9U{y>EN8&I{$8Zw%-9SdI{s{+QZmClf-xp`FqE`w5cR*9(bg45>=u zKp~fZ&4@-3_dC2f@*+%1&WCj-EP!KHLl_Uw!pghbi2DDB6I6IzfGH0VO~>#XnKmzj zND8Go03U&@#)hgh3*Kr{sA9URICx0A63EcDCHn+GWTD%?OFI5i%ZmU=%;m%oWGNk( zlZoe7_LN$jUjz^5XX1pZ7&GIxVgAVgpt{>0E`0R_kD0m>{?UauZbc0PsX zzN#+BQaUIqv53ZU&nAflo9Y|MKl*9D$KGw-LC>v&fuzb2mED0b*3fntS0kVSyJ)~; zCb#WC-u8G58(+UK-$CAnyAV`KP`ccp$cU=+ub~LNAG{PLUH3;W(bM3HJyqe533_mJ zS{o3Mn4;quC~$}hUtUo(Z{SHyauqQuO*P|M8)n%iz0yQg`^{qJeXj=G&1-q=V~c%!7fDTb z*OObg&HV?kZQ9@A4fh$a?Xr&uhA&eH(j6O=NW`|Z zJ#+!4oFrGO8y&MX$(4TJL``yaaEXc}SH*Qzaqy6Kr9C(_v7!+U-HI7r=td~9h{n^U zBK|SCiPZ3RKHNO&5KdTWh27p(!RwBlhg0Mo!4?-iJo-=+y!EmPY8F);TX@gHaXZq` z4+Ex9FDNAxm0WU_Za?-C8?~rEf5~LK$L@EtLCeKyz#`g9RVgEmF@AlDuNiC1Z9Z9> z4Z7yehW7lK%a8t+XI!GWRrsx-bh$y1F;(ebLlIh^`~rDZyN@dM-vZ;8n?v^vTcBa` zNwD5CQ}jvraDob7UQrYfn95tc$R;CKWYmh?+h+$z=#DB_ z%tMu5t%vWr)+RpBe3K>W`qcWmfK%M$iVl_kqq#V0$uvXcAJqTCxi}=YVFr&G%q696 z1jp7*0e{Ussd2z&^%fT8fZ|{w?L~WOh$RV7id?8=MS!B)fyBCvCDL94tflT~1lp3; z5|6Uqg8zCI149Nmqq#>myp>JopgRNcC6r*`7%^c4NCd&CM7(aX#U$-tN#j){%I{7aHJiWwX{9#T`v?4IcN$Eq-TqI4yf?D z08<{4tN32O7f7ziW{ODTA6fG76I#qZ~T z0~ZhKkAqwR_O!o`7bdKO-yiG;cHY^Bdu-niAN+9(RlFXHo5)OX+_ZF5scM;xXe{?k zPcURwi?)lArD*tZUKhD|Czj2RdRPaigG}4u#om5Jv9U*+JK>*$~uipEh_>P-3}xc*+`@v zFI+*{^7GPWBWt6~kegt$PdjP#ovqM}<-Ji_5(hS>_LJ&8H$c~FuK<_l2cZcM&I=Nh z4n|5Wxsqg58)*ZKf+6qOz>&A!dI;En=^3vLCh%!Jq%70x3709!;x5WBGnKOw_)5PF z&gefSQBby=V8om%R7{L~K!`AAMasGsTCM2M8YE~=-wm+ko_OeeZYJ?;sNd%#?09Z9 zc+=&p41m+VnJ>Z!QB^Igrk_T={lY$)0>76c4yf*@5q>XKgaa0h5n@z|4uX_Vu|m!Z z3vI6w!iaPFaP0p2h}3`^4fW$3af|Eyapb*SIH95?ZqnlvI12%nKP(!ynnH_&e79+;z^X zt|@n?DZ@5PP2h%%ufjK}lq@J+rf?#bFq61MLRAt6%Gd6rX~{f*IfMJc)sd-ieE(2v3;Lc0B;8fK@ z7+c348?HMKCnxs<>$dfRS*wf@yX-D@crja$rF2kKa*2_LNsK;!JAWRN*nO7DFz7bJ zFJI{XR^|z;xH^18_hoJW>>8DEwPio%#M1%1^*A4PP~Z(g*)oM7-El1ekOfsp96(x-V!3ykA7{d4OCw?TSDFY%m*-PqbpfYLCq|78HHneh0X0+lNMclk z1B!#?zmOO)`z+8}|9jGgY0hYA+ooV^_8w`~2`MOHLjsyt-4?94u?j!1io(WMXMk5{ z{zBp7=hC(5;Mp!lsj??UQK}Gu*6~l{c#+XI0UEJg+j^F zB=|OQGFW<6fxTDNg0G!l;Hbb==z8=RF!jcAkTkHgf4W&>5sjt0dCEYx`Os*7UY|ca z*fn|V8`s_J#bsH1R+D_@nDJ1~*d>*1TJ=02vzOAB6lD47|P9>#XJa}bm*QwS1IuB@p-;sE+A zxl+4f2~6pOA>Iw(M%z$0{nr=p{^Z(l1uG>E&QyOlk!*D>R9IcWDQ2*6*APon%ZdO+w*!gw)kvhSAXHlW+$uEBb1439{O|bDgKXGs z))J(Qyo4*)x()Z_uftb6PeK)JdV|!Q{m7?!rXWG-V5G#7D@ledXFhu|&5y6i9OeF; zQkQ-1)<xl5sHuh;HVRstVbLW#I)PT zw#a>?%4^!PCzJS@qjN@|ll)G1B+0xc4b&S9RO}4>~o$tWe znZ1c8g9XQH!4!8JIH>2X0v@hapu(0DjPRReqcU3sHYe}kBMzA9u)t08{=G-c(4ri$ zXpB&FQ2Tv~7J6K>rIsxrjMRL27X9W0q=qBS!J9)X(Xk^3@t*k?u>JU0uxg_j+Bg0M zUVU>Ns<9&5|lNJc-Zy*g+F*Dn~|od_y;R?Fq>xPanqt3u>%Z)`Rc)~ zJ(k?q!Yym>$(-vmi#c7pGtVYH6qGF|(MV;gP%$y0_PG*z!yarTZ3g^e4@38X-?PUo zHS=Cze!J%|Z`Jbx953=RScCgAC3JF90rUBSw;p9$wwpykUL03pW?Be@v{Nk$0|+%JAJ#uIs{moWGNb zt-pK}-^8)8`vI3*%>M<5MlRmIFpaitS;L5>dsN^i=Dl7n zA3UdDde-W$?3K5f?Eba({MzdyGiLAgV*MVkX6**-=3e_baMI|$Y?JERfL_Izrs@=8N;yP!RT7jMUXM$5jfScA51o62`7__v zb`W!A49YlUoWXyWElhnaFRvc2N?T7HKsrZz)IMmY>qi0f7_CQ<3bi2t3n$H10FX4b8Y{4{c%xfTRamaC+Byf&`_5krEs2 zvCIfEb!Gcm2lLK09rzdQ0(OzZR=&^aWZ=Y{P5-?m%QwA#iZvc}DPy*MHn&ES$|O|I zWWtA=vg79nPhZ)>iP+Z*nM!9*+ff`SUpt~%@M8lw+_W{UcCs5AH(&{@dp8-HcO1sZ z7GxIic^w5Rye`0$gNTMqr6aE%CsXNEi%%GkaUt^Z1~QedxUMP=9@4G^HZ+x@`=m^d z>9G0Vhp6R603_ye;t`Ee4s01ULYxZmT^eKy&**oepuP4#ok7s=yxrgitLD@2uRO$XgA^;L5XUzgIACUdWk>LMP>RZYwnOrtyEG#P{*| zX^01!nfMr{v^PNQf18aZE6?F|OM9cS-`|855AK7+Q!U}R+6vkfWv#1MlF{359D8V1 zcP_=GPR7CHeeCt4UHCtr4ahh?t}k2Fa~juVZ4Uo!#G~|?_wF&59r|V5?eUSB9QFtTyBMGpKAi*YxO{~ zn=-u0x;ZsjS;L5>`_}2Etix<4-g5YRZtlATQ2kK=sF)MXC)wWQ8a27eg7H(?o^3t2 zb9>FXSg-MH?@qh;Z~2y-=N5lK=`w{Av1FITCBpy35(moH?q6x!bPde)+YWDeo`BYy zdVtC^S-#uvr<2ZW+KWBU! zuJ);3>=KXt?DCFLDx1W~T*tw?n47;jvOi`|<-(_DvI%Ry>v8s6Y$9aDttLb5%@D-1(It| zW({v`0It1lG(TIMF{UH14|z~rS1lYxv|E>;hhm9tKe2VeBDH*pp{R;eJ832TEaZFl z5kBkl4n3aP3eNLD$f5JEcolkvG9#T(#}pTo$k%{&9hA`8#R^q?(gn;tYbVI^pF&X# z70W1C79{UOOBDlriI~D!+d#Q2ELa}q6BObn5A+F@g~@_G{zR>4m5*!iaGb~W1R0^W%3%yZhzSh!-eTu`i_O)P>U0>c$;l_CD&p=2.0", "numpy", "scikit-learn"] +# /// +""" +Generates count:poisson XGBoost test model. +Run with: uv run test_poisson.py +Outputs: test_poisson.json, test_poisson.ubj, test_poisson_expected.json +""" + +import json +import os +import numpy as np +import xgboost as xgb + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def main(): + rng = np.random.default_rng(42) + X = rng.standard_normal((500, 3)) + y = rng.poisson(np.exp(X @ [0.5, -0.3, 0.8])) + + X_train, X_test = X[:400], X[400:] + y_train = y[:400] + + model = xgb.XGBRegressor( + n_estimators=10, + max_depth=3, + objective="count:poisson", + random_state=42, + ) + model.fit(X_train, y_train) + + json_path = os.path.join(SCRIPT_DIR, "test_poisson.json") + ubj_path = os.path.join(SCRIPT_DIR, "test_poisson.ubj") + model.get_booster().save_model(json_path) + model.get_booster().save_model(ubj_path) + print(f"Saved {json_path}") + print(f"Saved {ubj_path}") + + test_rows = X_test[:5] + dtest = xgb.DMatrix(test_rows) + # output_margin=True gives log-space raw scores + raw_preds = model.get_booster().predict(dtest, output_margin=True) + # predict() gives exp(margin) = count prediction + count_preds = model.predict(test_rows) + + expected = [] + for i in range(len(test_rows)): + expected.append( + { + "features": test_rows[i].tolist(), + "raw_score": float(raw_preds[i]), + "prediction": float(count_preds[i]), + } + ) + + expected_path = os.path.join(SCRIPT_DIR, "test_poisson_expected.json") + with open(expected_path, "w") as f: + json.dump(expected, f, indent=2) + print(f"Saved {expected_path}") + + print(f"\nXGBoost version: {xgb.__version__}") + for e in expected: + print(f" features={[round(v,4) for v in e['features']]} raw={e['raw_score']:.6f} pred={e['prediction']:.6f}") + + +if __name__ == "__main__": + main() diff --git a/internal/xgjson/testdata/test_poisson.ubj b/internal/xgjson/testdata/test_poisson.ubj new file mode 100644 index 0000000000000000000000000000000000000000..c70dae2f4e06f02e3e1e4e0bbbfc06fa1fff5309 GIT binary patch literal 12509 zcmd^`30Msjv;F@$7E7ej*C1rNltQq)#k zZ4qs$2l}^AiwFK5kX=MXDfMh^i?u(km3mY}>shPT?yk%dTb1;Ae11y0JoD^kXXcyl zoB8I=%;Y0|sLO*(ZX$J7GignSFBCe4u~G@C3`wVFz?XiBq>f3ol3KGanWSTb;d>YO zAQfO{raZDhfYGWmP^6jBCRi*qW7nWZvc*7}r1oWc9Yaf>+)1YD zy(n{{rLl^FukxoqSiwWo!mhAnh;M?9CbiQ^%J{0#PX?);sm384Dyb^@nY}ognk9&PL@8ebnW#%OFmE^Q) z@OpTIR=80HN@~MTsHN4vN|_8+(rj;)P~DYJ?tJp#lblc9eDdLw-~A>o*p{!!Leso> zo!Ka;Z#N#G^KKc@{s#gt=el79g(>(Ko{3ogmOORRI=+D?Ur@T7_u@ZJh*ooX)$FPZ zCJ4!`AqnOPN!E}A(}d(HBq-S5Dmc(%yG4*C)EDY_@kzlaUq1Oa(JS|%`~W(|MDg-1 zW&WLRLK=DAl06<4v&WGmF@hVBvl#cf?J1@|zS=L`Oui$$yyi>8J`@l@V zb(%(e9^pqEayNpi=^S7pJ^<4?4*@>Uk?lJ38zgt%ZUIHA zpS=K7!g)~S3QcM@6cJ}TgyJ_td*C;YdSM%t)%c_jJ#(4}WM1O6g=az&FN32J#q5d0! zL==4$bDWKmqqZN?TOliUtB{rX@wSzn_Mq*rtrc)28Qr*^v0pz(Q^S54^Y>mAxj~bf z4@K)QU%)Gg2>jR$7c6gdHQooggFSopAeL!;t0@*i<^KUF5?4jtj*mx&7j#g6a{F7> zL0A3jC~9&=Wd9}b)K*0~@)tFxo+04%sw9oP_Crvab(^?eyArhPz69i1!@z_>6ihsB z1pTq4#D*{h(SvGDe3`v~=o!}odO^!WQM-bx?9J}|(Y?7{xHocBR0l_fqlV-40*X|N z(Kpx~dNsS{$70nlw~nyGhg?PSqFuxd+MqLAB%_@MMefj~=0j0pZC`v{-emkUa}3sp zI)+EQxe;4B;d^ySU+kZbqS3d!(BKna3+O4`KXyfF>ddif@@pk}Mmg0{S9Kgk_Pf_6 zaRlAwsiA1qzAKve6$e4?*+@`*#t)n)=7ag$%L(0kv+xpnC^-B=HL=Uo8WbC6;|E`P z4kV@TBL*7=Vte+yCF#=QRz>Y(bGZ>C0#N15m$|7W9nk9oGm&$N=^~2QxYq3Yj0?yF zRlR^8+e`cIA`AQ75jSXMb2>^!I}M6t(4^)=QRL}`cu;&4{;9DN%RkSkZ_QNUfy;Ja z=ej@t4@VK2;#sA>cG6v4X1QFN?@@_)+?>IV+z}w6$U#^AYbb&p-V7+JxTo>@piC2b zDO%(H-co#B;sH%tH!ot<#Kpv*!f~M9bublUfhM9*~qhchIPMU+lx)rLxo;UZ0I|MZ?tDCmu&2={+0|Cq)rl6=AYV zn1CW-RmANN55l@+xI5~qe+@-a&o=b{d;Y0y&~g%<2fwtRtNHhy08qF2G`@57kD6Ya zcM{zyJ_7R-j)HOT1mgqIn|R3eFF_B=q;VY;i?8f62rr!MCF#=AHfYL{Rwl07yI&#S zmfz!E861I>UCTw6WX2-FrzWx6uLRpJcGfC$pWBIg&dFCD`Yaw*#An;S^KYj*x1#Z^ z-f2)IhvV3MDB6(O8!H&z9hYS*Fzet6>RWHk$Li{S!KU4t?C6AWW)u~^_GM|i;t>!< zYTsF*Y#tcqsH-}LB6v>JREJCX&~2U`iVpn?T-*vY%Dqvb65C6>$@zlVtL-$h@**&P z+cO~Qn316Fb-@SRxkyysMTweob)a+FT2OMRy`)QvL(zh@HS8PuJE;1l!JOvUUDed( z`%!j67E=4$BE)sQm3_VAWp?S$#j5%BHl*{tAT+FvO11p-J;ZqHl4P{gph)^kX+9J! z7(EIPh+T%Qs(%Mlab>Nfxd0=bXEBtv8Lmd3IE12Qd z55H953f4!hBbNJ@gEMqLB04&P7+2y0+PpWDs60GN(xt_rsP~);T$g3vAz68oxR{fh z(brDuRH{w!sPgC-WaEY%?3`|gk$$_XQTOOtTXu9RBAYk@S^i>NN$)h1WVF+u$Qzo} zd??cB1F$y_6ktVF%P>oLC0f~QBlg;vb=V7$W1cu0_)Mt0lcGrVYyFSnvGtU1OZLh> zE-#I^)EWMQ=BTSWh9aS(A%g7Bj2=4xu;-s1ib~|5==mzm#MAlU!{jk|$u1KgUqphX zE3)uK(F|Doj16ZrgK+m=%fUyd-Uf3D7!bEBf*^NImvs4Cp-2k(mI)E^QuL$eg8*|W zw@pK|{UamJ6=I+^t+!Yk1$f7X_wF0ncR*Miv^!9YpVuZ@Ey-G48XsP6Y%KkUz;wMy zM>hthI~?k%5QExvRwKzshVWsWPyugCLJ}Y2Nu`leT<60^`{^yIW+o=ZLeaEE+FN#P z_^h>(Myn9^$_K?i9H8o-teYVOs4_Y&{|^EwQrkXZz57LrS5#C+hNOp_H(rp^81wGV LN7K7eE*ZZArc^Gn literal 0 HcmV?d00001 diff --git a/internal/xgjson/testdata/test_poisson_expected.json b/internal/xgjson/testdata/test_poisson_expected.json new file mode 100644 index 0000000..6a23356 --- /dev/null +++ b/internal/xgjson/testdata/test_poisson_expected.json @@ -0,0 +1,47 @@ +[ + { + "features": [ + -1.322541133291679, + -0.4861941306990469, + 0.42022670947458574 + ], + "raw_score": -0.01655573584139347, + "prediction": 0.9835805296897888 + }, + { + "features": [ + -0.10239670661109637, + -0.6505636204055921, + -0.674212461724509 + ], + "raw_score": -0.48458048701286316, + "prediction": 0.6159555315971375 + }, + { + "features": [ + -0.7123370644127087, + -0.8795096286589426, + 2.2816328772760617 + ], + "raw_score": 0.8761682510375977, + "prediction": 2.401679515838623 + }, + { + "features": [ + 0.2975110711855549, + 0.8867590544047386, + -0.4890774681189237 + ], + "raw_score": -0.18565593659877777, + "prediction": 0.8305593132972717 + }, + { + "features": [ + -0.18595667718604886, + -0.713553981370355, + -2.651708334919202 + ], + "raw_score": -0.48458048701286316, + "prediction": 0.6159555315971375 + } +] \ No newline at end of file diff --git a/internal/xgjson/testdata/test_regression.json b/internal/xgjson/testdata/test_regression.json new file mode 100644 index 0000000..825add4 --- /dev/null +++ b/internal/xgjson/testdata/test_regression.json @@ -0,0 +1 @@ +{"learner":{"attributes":{},"feature_names":[],"feature_types":[],"gradient_booster":{"model":{"cats":{"enc":[],"feature_segments":[],"sorted_idx":[]},"gbtree_model_param":{"num_parallel_tree":"1","num_trees":"10"},"iteration_indptr":[0,1,2,3,4,5,6,7,8,9,10],"tree_info":[0,0,0,0,0,0,0,0,0,0],"trees":[{"base_weights":[-3.4900674E-7,-2.1703766E1,3.888591E1,-5.660842E1,1.1113851E0,1.8131002E1,6.76406E1,-2.410375E1,-1.3382771E1,-2.6366928E0,1.2017798E1,-1.8681777E0,9.061715E0,1.3893869E1,2.3894E1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":0,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[3.3927622E5,2.055225E5,8.471297E4,2.5468469E4,6.0891074E4,2.5188213E4,1.1675E4,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[3.3198032E-1,-2.0812225E-1,1.9204912E-1,-7.2713715E-1,1.2897527E0,-7.728777E-1,6.982233E-1,-2.410375E1,-1.3382771E1,-2.6366928E0,1.2017798E1,-1.8681777E0,9.061715E0,1.3893869E1,2.3894E1],"split_indices":[1,0,0,1,0,0,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,2.57E2,1.43E2,1.01E2,1.56E2,8.4E1,5.9E1,3.2E1,6.9E1,1.25E2,3.1E1,2.8E1,5.6E1,2.3E1,3.6E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[4.208488E-2,-3.1665228E1,1.6616203E1,-4.209725E1,-9.3845826E-1,-1.0655563E1,3.008117E1,-1.8151125E1,-9.645155E0,-1.0015606E1,2.8573256E0,-7.5468655E0,3.6063957E0,6.482903E0,1.6121021E1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":1,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[2.1125997E5,4.4195344E4,9.7264875E4,1.69125E4,1.2539539E4,2.9239414E4,3.4309188E4,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-3.357847E-1,6.725737E-1,-2.5497723E-1,-6.095122E-1,-1.3044695E0,6.4537597E-1,8.700677E-1,-1.8151125E1,-9.645155E0,-1.0015606E1,2.8573256E0,-7.5468655E0,3.6063957E0,6.482903E0,1.6121021E1],"split_indices":[1,0,0,0,1,1,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,1.37E2,2.63E2,1.02E2,3.5E1,8.7E1,1.76E2,3.4E1,6.8E1,8E0,2.7E1,5.3E1,3.4E1,1.31E2,4.5E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-1.1167165E-2,-2.2418953E1,1.1702036E1,-5.0537853E1,-1.5123926E1,-1.5157427E0,2.8579166E1,-1.5612584E1,-1.4890947E0,-7.604434E0,1.3045177E-1,-1.0831369E1,9.6371937E-1,5.392886E0,1.2125985E1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":2,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.0551161E5,2.754336E4,5.893595E4,1.6610234E3,1.7668422E4,2.4483562E4,1.3836781E4,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-3.357847E-1,-1.3719012E0,1.9204912E-1,1.2897527E0,2.5049284E-1,-1.5255252E0,8.004095E-1,-1.5612584E1,-1.4890947E0,-7.604434E0,1.3045177E-1,-1.0831369E1,9.6371937E-1,5.392886E0,1.2125985E1],"split_indices":[1,1,0,0,0,0,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,1.37E2,2.63E2,2.7E1,1.1E2,1.48E2,1.15E2,2.6E1,1E0,6.6E1,4.4E1,1.7E1,1.31E2,6.2E1,5.3E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-5.002281E-2,-7.8147173E0,2.0562166E1,-2.491751E1,-3.1489772E-1,1.1856391E0,2.5955784E1,-9.092211E0,4.703793E-1,-4.403968E0,1.952447E0,-5.087349E0,2.3423016E0,5.6828566E0,1.2984971E1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":3,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[6.4339684E4,3.7446367E4,1.1465273E4,1.2752523E4,2.009022E4,3.1360671E3,9.630484E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[5.5979043E-1,-4.6341768E-1,-7.539646E-1,9.1786194E-1,-6.414816E-1,-5.100164E-1,1.3993554E0,-9.092211E0,4.703793E-1,-4.403968E0,1.952447E0,-5.087349E0,2.3423016E0,5.6828566E0,1.2984971E1],"split_indices":[1,0,0,3,1,3,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,2.91E2,1.09E2,8.8E1,2.03E2,2.4E1,8.5E1,7.3E1,1.5E1,6.5E1,1.38E2,6E0,1.8E1,6.2E1,2.3E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-2.2922823E-2,-8.812916E0,9.98283E0,-1.2287054E1,2.057175E1,3.8232608E0,2.4433542E1,-9.349065E0,-2.6622648E0,-7.8267634E-1,7.2537894E0,2.2643733E-1,8.955402E0,5.292737E0,1.1337043E1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":4,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[3.535613E4,2.209925E4,1.6640496E4,1.2092553E4,1.9864639E3,1.0599696E4,4.4128164E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[1.042011E-1,1.5033983E0,5.6091946E-1,-9.983854E-1,-1.2239403E0,1.6777008E0,1.1581109E0,-9.349065E0,-2.6622648E0,-7.8267634E-1,7.2537894E0,2.2643733E-1,8.955402E0,5.292737E0,1.1337043E1],"split_indices":[1,0,0,0,2,1,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,2.13E2,1.87E2,1.91E2,2.2E1,1.32E2,5.5E1,2.8E1,1.63E2,3E0,1.9E1,1.19E2,1.3E1,3.8E1,1.7E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[3.450642E-2,-7.6010866E0,7.9013066E0,-2.0408184E1,-2.7366667E0,-1.424733E1,1.0195322E1,-1.4414116E1,-5.0269175E0,-2.0192778E0,4.2070208E0,-5.7176957E0,-9.828203E-1,-1.6186364E0,4.080827E0],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":5,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[2.41472E4,1.2653163E4,1.0205502E4,4.8470625E3,1.0069055E4,9.405698E2,9.635811E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[5.9218433E-2,-5.449191E-1,-1.2608839E0,-2.1356742E0,8.696059E-1,1.943843E-1,-8.254972E-1,-1.4414116E1,-5.0269175E0,-2.0192778E0,4.2070208E0,-5.7176957E0,-9.828203E-1,-1.6186364E0,4.080827E0],"split_indices":[3,1,0,1,0,1,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,2.03E2,1.97E2,5.5E1,1.48E2,1.8E1,1.79E2,5E0,5E1,1.2E2,2.8E1,1.2E1,6E0,3.2E1,1.47E2],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-2.9440423E-2,-5.389619E0,6.9236445E0,-7.6967645E0,1.5408404E1,-1.2216392E1,8.663874E0,-4.9256887E0,-1.1134292E0,-1.1573362E0,5.547367E0,-5.031951E0,1.3524102E0,1.9145297E0,6.08035E0],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":6,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.4982496E4,1.10109795E4,5.93475E3,7.065041E3,1.4374014E3,1.2187563E3,4.134176E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[1.504189E-1,1.3902075E0,-1.2515395E0,-5.100164E-1,-1.51937E0,1.3071427E0,9.633761E-1,-4.9256887E0,-1.1134292E0,-1.1573362E0,5.547367E0,-5.031951E0,1.3524102E0,1.9145297E0,6.08035E0],"split_indices":[0,1,1,3,0,2,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,2.26E2,1.74E2,2.04E2,2.2E1,1.4E1,1.6E2,6.3E1,1.41E2,3E0,1.9E1,1.1E1,3E0,1.35E2,2.5E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-1.395906E-2,-2.2662601E1,1.1718702E0,-2.7021698E1,-7.9877605E0,-1.3520322E0,1.1025952E1,-9.467001E0,-4.487939E0,-3.1854148E0,7.7455217E-1,-1.2673126E0,2.162385E0,1.5491321E0,4.4347887E0],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":7,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.0796384E4,1.0635391E3,9.515531E3,3.4842383E2,1.9422064E2,7.529547E3,1.6414043E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-1.5566292E0,5.8392817E-1,9.1786194E-1,-1.7131345E0,1.521316E0,6.725737E-1,-3.426876E-1,-9.467001E0,-4.487939E0,-3.1854148E0,7.7455217E-1,-1.2673126E0,2.162385E0,1.5491321E0,4.4347887E0],"split_indices":[1,0,3,1,2,0,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,1.9E1,3.81E2,1.4E1,5E0,3.04E2,7.7E1,9E0,5E0,4E0,1E0,2.28E2,7.6E1,3.1E1,4.6E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-3.3429414E-2,-2.614427E0,7.8663254E0,-1.4629532E1,-7.3140913E-1,-1.6345493E1,9.057828E0,-6.7656684E0,-3.291841E0,-1.1315507E0,1.3278493E0,-5.9210744E0,-4.1697222E-1,1.190232E0,3.8547845E0],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":8,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[8.196654E3,6.8445713E3,3.0040518E3,9.4035547E2,4.140856E3,2.261709E2,1.7825186E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[6.454842E-1,-1.0676204E0,-1.6685841E0,-8.0227727E-1,3.8240975E-1,1.0666747E0,-1.0594835E-1,-6.7656684E0,-3.291841E0,-1.1315507E0,1.3278493E0,-5.9210744E0,-4.1697222E-1,1.190232E0,3.8547845E0],"split_indices":[1,0,3,3,3,1,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,3.02E2,9.8E1,4E1,2.62E2,4E0,9.4E1,1.1E1,2.9E1,1.65E2,9.7E1,3E0,1E0,4.1E1,5.3E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-5.1216763E-2,-9.203662E-1,1.65521E1,-4.6210365E0,2.68967E0,1.094314E0,1.9215572E1,-4.8870697E0,-9.8923075E-1,-6.7744327E-1,1.7080133E0,-1.0934024E0,1.1666605E0,2.6789975E0,6.5080333E0],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":9,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[5.8019707E3,5.1157793E3,8.023999E2,2.8654568E3,2.8960432E3,6.71471E1,2.3954541E2,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[1.6420152E0,4.886007E-2,-1.279577E0,-1.5625459E0,-2.8865865E-1,-1.0372461E0,-5.642476E-1,-4.8870697E0,-9.8923075E-1,-6.7744327E-1,1.7080133E0,-1.0934024E0,1.1666605E0,2.6789975E0,6.5080333E0],"split_indices":[0,1,2,1,2,1,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,3.81E2,1.9E1,1.88E2,1.93E2,3E0,1.6E1,1.8E1,1.7E2,7.3E1,1.2E2,1E0,2E0,4E0,1.2E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}}]},"name":"gbtree"},"learner_model_param":{"base_score":"[3.5165477E0]","boost_from_average":"1","num_class":"0","num_feature":"4","num_target":"1"},"objective":{"name":"reg:squarederror","reg_loss_param":{"scale_pos_weight":"1"}}},"version":[3,2,0]} \ No newline at end of file diff --git a/internal/xgjson/testdata/test_regression.py b/internal/xgjson/testdata/test_regression.py new file mode 100644 index 0000000..5a4629f --- /dev/null +++ b/internal/xgjson/testdata/test_regression.py @@ -0,0 +1,65 @@ +# /// script +# dependencies = ["xgboost>=2.0", "numpy", "scikit-learn"] +# /// +""" +Generates reg:squarederror XGBoost test model. +Run with: uv run test_regression.py +Outputs: test_regression.json, test_regression.ubj, test_regression_expected.json +""" + +import json +import os +import numpy as np +from sklearn.datasets import make_regression +import xgboost as xgb + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def main(): + X, y = make_regression(n_samples=500, n_features=4, random_state=42) + + X_train, X_test = X[:400], X[400:] + y_train = y[:400] + + model = xgb.XGBRegressor( + n_estimators=10, + max_depth=3, + objective="reg:squarederror", + random_state=42, + ) + model.fit(X_train, y_train) + + json_path = os.path.join(SCRIPT_DIR, "test_regression.json") + ubj_path = os.path.join(SCRIPT_DIR, "test_regression.ubj") + model.get_booster().save_model(json_path) + model.get_booster().save_model(ubj_path) + print(f"Saved {json_path}") + print(f"Saved {ubj_path}") + + test_rows = X_test[:5] + dtest = xgb.DMatrix(test_rows) + # identity link: output_margin == predict() + raw_preds = model.get_booster().predict(dtest, output_margin=True) + + expected = [] + for i in range(len(test_rows)): + expected.append( + { + "features": test_rows[i].tolist(), + "raw_score": float(raw_preds[i]), + } + ) + + expected_path = os.path.join(SCRIPT_DIR, "test_regression_expected.json") + with open(expected_path, "w") as f: + json.dump(expected, f, indent=2) + print(f"Saved {expected_path}") + + print(f"\nXGBoost version: {xgb.__version__}") + for e in expected: + print(f" features={[round(v,4) for v in e['features']]} raw={e['raw_score']:.6f}") + + +if __name__ == "__main__": + main() diff --git a/internal/xgjson/testdata/test_regression.ubj b/internal/xgjson/testdata/test_regression.ubj new file mode 100644 index 0000000000000000000000000000000000000000..59db117971b14cf809ccca33a1a215b8affee1f2 GIT binary patch literal 12498 zcmds-2~-qE8i0pOaDXs^8rQfHgSb9}1`p5&I^8`-58iil*XTG$gN?%s_KbKSVI({x z>Z0OpJW%76c*Lk#-5}i!uA;_ZHfr#U@2xSqYStt1j_zLr)eHu4_C59yrsn;xs;j!I z{;K->o4>pN%t5lM6lh@#oR#4+afRIAaB!?~s)Mm-X5q@vWX9l_$}xJYA&s$*i^BCX zT#!z1%t){E1VsvGFtd!+p*PxWc8A-50ckcfW0A)9HyIpu>Bx_sj;kTJ6%b!c&Y6j$!mJ4fN>-&X6XJ+`>9F%~fTwFcv*JZ6uzo zPj_5uj!r^tcB#>~d#u#If`y419IVZ%XRYRR2UjzdU{=QZu)eGx>(2(T&DcOTh;7cw zrH);CvDRc;%_VZUVQzoHxGE3_ew)abYdvVA5VX8Og z#grOt&4>c(616B~plXipS5Y{XDZbnOl+c-9Op!@H^8bnCD8pm;O?W~$D9VDqL*_nHHc6KJ}u8y-p%uKRjs>PwVFv*T_ z9Y;(3gjHKT&hRXJSu-n*5sI$TnCE z+3nE1!J2~9ciWGoP21)qCG9QJZF0<_4?n!EOFJJyA1Y6y<1H{i{R#IT2v+>!4H1?N zs<~Z`rhvq^TF`Wm_*V;>8j@y6aL2E4$beqDET}JZJ-Q|cL=F-H^4cp_&krMdl$3dDs-aaQz@5mNYTF^+ zSJ2m$(Bk0aUdDB`ay$WfRrd&ho|s&)9ndj}vD;Y#*qECYsLN$CwbLn6gI20zP;LkL zRuBsajI5dq@>nLL9)etk@D&%7TK!XJ`ik|Gu312AE$gr204wjMC@T5%#v>x8>#>rY`4jne zLC^Tj56=_Pz274y6qkBxs-B|=tuI0lmO9qErV6iogQ5uj;M5;DWvZ_LF(Dv3;TUW{ z1(a9i1A1V1suIehs_NA6PonD8gi%x+e}9&^=P&+LV&dG=0iDXkDNk1MrGDqBD;wSt zXPk)UOIB?a?v2S)1}yC)b{#x`$oOIyxxFxtQu#NRHlW6zarYl2KD)ayzi+R|7f-3t zP^9}h;%oZ5Zu@n~z3*u%iUfLYP$kXlBI&U`BWp~*?h;nsi&4byS~V4UMaUykI?0uM znT90>{kYLnQ++jx5PJUSgjfFmfg%@fE&zsF0&1=Teoo3{gcHiSAb=(cu;jwfbT|eq zNxC#Oujt^QEBl4V;a#YYvxkdE#)OruPoRnAWzDFHO`O?V@^y;6dXf-|&bn6+n zUv-CI<-HU|MA(qDl3UtVJmvQ-k@`r+ckC0z`!2K*d){p6sj0piMR<|X(<;Ie&T)O6 zIbQh&MbRRsOmz(D2kPM&n1p(2AshpW&_kM}OH)Hp@t5VB z#898TRQm(lMa}ZV)W?N)NyleHsIqB4h^?0XL+E0h&nJzaNnUFotW5bnmI~`XirhA# zEfr^;BTd?CP~?xN)OaY8x9dP}-|VDwWY6fncjGk=hg8zhYmaI_ybAw*^(U;nm!hcO z=@b(o*v1iID<<>V%VPP}pF|L6O0E)1`abs5R6R!#T59z2%OfuduY7}|Xf&WE5rhE5 zARNJKIMsVlz6|9kKoAM#6;SsR2&J_O1LrVO;~wvMG?0kkPpuOhF@>qBn0jJ z7cpV&`$}ENPJZObqn?^-{U~zxfG~nkj>|7yeZ2AwiX!lY)Y>ZipocT$`KWsWf{LIn z3FJ1^T?QecPJm+>v_~{y6cyj;8Z2y|R!Uwh$QJGo{DI2I|5j9IjG^S~5`;DDlzc(% zS$;{>3UYUDkgz_WgxK-M5#pDyp=4=cg0ul%;}uEYiW(0^segD%#{}Qd9>|bsFO-kc z?AczaSuwVkc6RqkU~BazydPeQB4yjaRATqt5qwb(4n3Ihr8ECVY#HVtGS+?Lsj1eF zBDA^~JqSxEdE^=4m2XfK838TbK%jHv6;TtR90|yo1$ijohlD!tim2GCvKqeSQsqq< zMdrQp#C?N(iKWjg#mbm;;rU0!{Ejd)6`ML(e0};EVZoXn!oZ7X2)chLzoGYFQf|&A z4)?!L{`LAXX#>0lMS*xqjfbMnyFb+A-XZChJ0sOQ<4Luu&n)eYK@&A6V&Gq}{)CnH zQWWt=h$5oZtl_TrMCWzSh|<5$a(z+T!JyZWMqXI3}G{udmf+s7FY82zj3-jG}`c0JvCL&QG{^xB9dOZ{CMRX5=H9efR;lbk$|BOs?eh<2IPPjL|ucrXlN^~ z@`zkt)H=d(lm3D>y!=ys@|yNkZt5Ok@%Q;;OjH3eKRlCKdEk9vwed2arA*+vT5b~O zin{UheqK#19sL&hw&Qd1Bvm18fY*3MQhY+Ap-8)L=ppU8OMNv-D_3ZHzCKhNC)=&c z8#r8Tn^*MGJE6J(?}xXd=$~B+`XzP9<~JVg#+UodE?IoM1)qB;lUH#0o|)INl)D1k4O|N2kR}76oLGV8He)Kpd9oRd6x&&Q z;tGyQNw80u3c<)`hU07;zVbC_(7Pg_Ye%Ol>?VVS(Wl$&)uHO{!)8eX2f}Ra5FK0- N=%S+(!Po!* literal 0 HcmV?d00001 diff --git a/internal/xgjson/testdata/test_regression_expected.json b/internal/xgjson/testdata/test_regression_expected.json new file mode 100644 index 0000000..2a50522 --- /dev/null +++ b/internal/xgjson/testdata/test_regression_expected.json @@ -0,0 +1,47 @@ +[ + { + "features": [ + -2.123895724309807, + -0.8397218421807761, + -0.5993926454440222, + -0.525755021680761 + ], + "raw_score": -80.28502655029297 + }, + { + "features": [ + 0.7326400772155792, + 0.7276295436369798, + 0.05194588580729943, + -0.08071658010858232 + ], + "raw_score": 57.882362365722656 + }, + { + "features": [ + 2.4457519796168263, + 0.2799686263198203, + -1.1254890472983765, + 0.1292211819752275 + ], + "raw_score": 60.96694564819336 + }, + { + "features": [ + -0.3653215513121087, + 2.013387247526623, + 0.13653533108273744, + 0.1846803058649084 + ], + "raw_score": 53.012428283691406 + }, + { + "features": [ + 0.19652116970147013, + 0.6427227598675439, + 1.3291525301324314, + 0.7090037575885123 + ], + "raw_score": 42.95940399169922 + } +] \ No newline at end of file diff --git a/internal/xgjson/testdata/train.py b/internal/xgjson/testdata/train.py new file mode 100644 index 0000000..fc8cecc --- /dev/null +++ b/internal/xgjson/testdata/train.py @@ -0,0 +1,21 @@ +# /// script +# dependencies = [] +# /// +""" +Coordinator: runs all per-case training scripts. +Run with: uv run train.py +""" + +import subprocess +import pathlib + +here = pathlib.Path(__file__).parent + +for s in [ + "test_binary_logistic.py", + "test_regression.py", + "test_multiclass.py", + "test_poisson.py", +]: + print(f"\n=== Running {s} ===") + subprocess.run(["uv", "run", str(here / s)], check=True) diff --git a/internal/xgjson/xgjson_io.go b/internal/xgjson/xgjson_io.go new file mode 100644 index 0000000..df3f5cc --- /dev/null +++ b/internal/xgjson/xgjson_io.go @@ -0,0 +1,468 @@ +// Package xgjson reads XGBoost models saved in JSON (.json) and Universal Binary JSON (.ubj) formats. +// These are the modern formats produced by XGBoost 2.x and 3.x. +// +// Use ReadModelJSON for .json files and ReadModelUBJ for .ubj files. +// Both return *ModelJSON with the same logical structure. +package xgjson + +import ( + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + + "github.com/dmitryikh/leaves/internal/ubjdecode" +) + +// ModelJSON is the top-level XGBoost model structure. +type ModelJSON struct { + Learner LearnerJSON + Version []int32 +} + +// LearnerJSON holds the learner sub-tree of the model. +type LearnerJSON struct { + FeatureNames []string + GradientBooster GradientBoosterJSON + LearnerModelParam LearnerModelParamJSON + Objective ObjectiveJSON +} + +// LearnerModelParamJSON holds global model parameters. +// Numeric fields are stored as strings (or bracketed strings like "[5.2E-1]") by XGBoost. +type LearnerModelParamJSON struct { + BaseScore string // may be "[value]" in XGBoost 2.x + NumClass string + NumFeature string +} + +// ObjectiveJSON holds the objective function name. +type ObjectiveJSON struct { + Name string +} + +// GradientBoosterJSON wraps the booster model and its name. +type GradientBoosterJSON struct { + Model GBTreeModelDataJSON + Name string // "gbtree" or "dart" +} + +// GBTreeModelDataJSON holds the tree ensemble data. +type GBTreeModelDataJSON struct { + GBTreeModelParam GBTreeModelParamJSON + TreeInfo []int32 + Trees []TreeJSON +} + +// GBTreeModelParamJSON holds ensemble-level parameters. +type GBTreeModelParamJSON struct { + NumParallelTree string + NumTrees string +} + +// TreeJSON holds a single decision tree's data. +type TreeJSON struct { + ID int32 + BaseWeights []float64 + DefaultLeft []int32 + LeftChildren []int32 + RightChildren []int32 + SplitConditions []float64 + SplitIndices []int32 + SplitType []int32 + NumNodes string // from tree_param.num_nodes +} + +// --------------------------------------------------------------------------- +// JSON format (encoding/json with struct tags) +// --------------------------------------------------------------------------- + +// jsonModel is the raw JSON-tagged equivalent used only for unmarshaling. +type jsonModel struct { + Learner struct { + FeatureNames []string `json:"feature_names"` + GradientBooster struct { + Model struct { + GBTreeModelParam struct { + NumParallelTree string `json:"num_parallel_tree"` + NumTrees string `json:"num_trees"` + } `json:"gbtree_model_param"` + TreeInfo []int32 `json:"tree_info"` + Trees []struct { + ID int32 `json:"id"` + BaseWeights []float64 `json:"base_weights"` + DefaultLeft []int32 `json:"default_left"` + LeftChildren []int32 `json:"left_children"` + RightChildren []int32 `json:"right_children"` + SplitConditions []float64 `json:"split_conditions"` + SplitIndices []int32 `json:"split_indices"` + SplitType []int32 `json:"split_type"` + TreeParam struct { + NumNodes string `json:"num_nodes"` + } `json:"tree_param"` + } `json:"trees"` + } `json:"model"` + Name string `json:"name"` + } `json:"gradient_booster"` + LearnerModelParam struct { + BaseScore string `json:"base_score"` + NumClass string `json:"num_class"` + NumFeature string `json:"num_feature"` + } `json:"learner_model_param"` + Objective struct { + Name string `json:"name"` + } `json:"objective"` + } `json:"learner"` + Version []int32 `json:"version"` +} + +// ReadModelJSON reads an XGBoost model from a JSON reader. +func ReadModelJSON(r io.Reader) (*ModelJSON, error) { + var raw jsonModel + if err := json.NewDecoder(r).Decode(&raw); err != nil { + return nil, fmt.Errorf("xgjson: JSON decode: %w", err) + } + return jsonModelToModelJSON(&raw), nil +} + +func jsonModelToModelJSON(raw *jsonModel) *ModelJSON { + m := &ModelJSON{ + Version: raw.Version, + } + rl := &raw.Learner + m.Learner.FeatureNames = rl.FeatureNames + m.Learner.LearnerModelParam = LearnerModelParamJSON{ + BaseScore: rl.LearnerModelParam.BaseScore, + NumClass: rl.LearnerModelParam.NumClass, + NumFeature: rl.LearnerModelParam.NumFeature, + } + m.Learner.Objective = ObjectiveJSON{Name: rl.Objective.Name} + + gb := &rl.GradientBooster + m.Learner.GradientBooster = GradientBoosterJSON{ + Name: gb.Name, + Model: GBTreeModelDataJSON{ + GBTreeModelParam: GBTreeModelParamJSON{ + NumParallelTree: gb.Model.GBTreeModelParam.NumParallelTree, + NumTrees: gb.Model.GBTreeModelParam.NumTrees, + }, + TreeInfo: gb.Model.TreeInfo, + }, + } + + for _, t := range gb.Model.Trees { + m.Learner.GradientBooster.Model.Trees = append( + m.Learner.GradientBooster.Model.Trees, + TreeJSON{ + ID: t.ID, + BaseWeights: t.BaseWeights, + DefaultLeft: t.DefaultLeft, + LeftChildren: t.LeftChildren, + RightChildren: t.RightChildren, + SplitConditions: t.SplitConditions, + SplitIndices: t.SplitIndices, + SplitType: t.SplitType, + NumNodes: t.TreeParam.NumNodes, + }, + ) + } + return m +} + +// --------------------------------------------------------------------------- +// UBJ format — walk map[string]interface{} from ubjdecode +// --------------------------------------------------------------------------- + +// ReadModelUBJ reads an XGBoost model from a UBJ reader. +// Uses ubjdecode internally; produces the same ModelJSON struct. +func ReadModelUBJ(r io.Reader) (*ModelJSON, error) { + raw, err := ubjdecode.DecodeValue(r) + if err != nil { + return nil, fmt.Errorf("xgjson: UBJ decode: %w", err) + } + root, ok := raw.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ root is not an object") + } + return ubjMapToModelJSON(root) +} + +// --------------------------------------------------------------------------- +// UBJ map-walking helpers +// --------------------------------------------------------------------------- + +func ubjMapToModelJSON(root map[string]interface{}) (*ModelJSON, error) { + m := &ModelJSON{} + + if v, ok := root["version"]; ok { + m.Version = toInt32Slice(v) + } + + learnerRaw, ok := root["learner"] + if !ok { + return nil, fmt.Errorf("xgjson: UBJ missing 'learner' key") + } + learnerMap, ok := learnerRaw.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ 'learner' is not an object") + } + + // feature_names + if v, ok := learnerMap["feature_names"]; ok { + m.Learner.FeatureNames = toStringSlice(v) + } + + // learner_model_param + if v, ok := learnerMap["learner_model_param"]; ok { + if mp, ok := v.(map[string]interface{}); ok { + m.Learner.LearnerModelParam = LearnerModelParamJSON{ + BaseScore: mapString(mp, "base_score"), + NumClass: mapString(mp, "num_class"), + NumFeature: mapString(mp, "num_feature"), + } + } + } + + // objective + if v, ok := learnerMap["objective"]; ok { + if om, ok := v.(map[string]interface{}); ok { + m.Learner.Objective = ObjectiveJSON{Name: mapString(om, "name")} + } + } + + // gradient_booster + if v, ok := learnerMap["gradient_booster"]; ok { + gbMap, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ 'gradient_booster' is not an object") + } + gb, err := ubjGradientBooster(gbMap) + if err != nil { + return nil, err + } + m.Learner.GradientBooster = *gb + } + + return m, nil +} + +func ubjGradientBooster(gbMap map[string]interface{}) (*GradientBoosterJSON, error) { + gb := &GradientBoosterJSON{ + Name: mapString(gbMap, "name"), + } + + modelRaw, ok := gbMap["model"] + if !ok { + return nil, fmt.Errorf("xgjson: UBJ gradient_booster missing 'model'") + } + modelMap, ok := modelRaw.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ gradient_booster 'model' is not an object") + } + + // gbtree_model_param + if v, ok := modelMap["gbtree_model_param"]; ok { + if pm, ok := v.(map[string]interface{}); ok { + gb.Model.GBTreeModelParam = GBTreeModelParamJSON{ + NumParallelTree: mapString(pm, "num_parallel_tree"), + NumTrees: mapString(pm, "num_trees"), + } + } + } + + // tree_info + if v, ok := modelMap["tree_info"]; ok { + gb.Model.TreeInfo = toInt32Slice(v) + } + + // trees + treesRaw, ok := modelMap["trees"] + if !ok { + return nil, fmt.Errorf("xgjson: UBJ model missing 'trees'") + } + treesSlice, ok := treesRaw.([]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ 'trees' is not an array") + } + for i, tRaw := range treesSlice { + tMap, ok := tRaw.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ tree %d is not an object", i) + } + tree, err := ubjTree(tMap) + if err != nil { + return nil, fmt.Errorf("xgjson: UBJ tree %d: %w", i, err) + } + gb.Model.Trees = append(gb.Model.Trees, *tree) + } + + return gb, nil +} + +func ubjTree(tMap map[string]interface{}) (*TreeJSON, error) { + t := &TreeJSON{} + + if v, ok := tMap["id"]; ok { + t.ID = int32(toInt64(v)) + } + t.BaseWeights = toFloat64Slice(tMap["base_weights"]) + t.DefaultLeft = toInt32Slice(tMap["default_left"]) + t.LeftChildren = toInt32Slice(tMap["left_children"]) + t.RightChildren = toInt32Slice(tMap["right_children"]) + t.SplitConditions = toFloat64Slice(tMap["split_conditions"]) + t.SplitIndices = toInt32Slice(tMap["split_indices"]) + t.SplitType = toInt32Slice(tMap["split_type"]) + + if v, ok := tMap["tree_param"]; ok { + if pm, ok := v.(map[string]interface{}); ok { + t.NumNodes = mapString(pm, "num_nodes") + } + } + return t, nil +} + +// --------------------------------------------------------------------------- +// Type coercion helpers +// --------------------------------------------------------------------------- + +func mapString(m map[string]interface{}, key string) string { + if v, ok := m[key]; ok { + if s, ok := v.(string); ok { + return s + } + return fmt.Sprintf("%v", v) + } + return "" +} + +func toInt64(v interface{}) int64 { + switch x := v.(type) { + case int64: + return x + case int32: + return int64(x) + case float64: + return int64(x) + case float32: + return int64(x) + default: + return 0 + } +} + +// toInt32Slice coerces any of the UBJ typed array forms to []int32. +func toInt32Slice(v interface{}) []int32 { + if v == nil { + return nil + } + switch x := v.(type) { + case []int32: + return x + case []int64: + out := make([]int32, len(x)) + for i, vv := range x { + out[i] = int32(vv) + } + return out + case []interface{}: + out := make([]int32, len(x)) + for i, vv := range x { + out[i] = int32(toInt64(vv)) + } + return out + default: + return nil + } +} + +// toFloat64Slice coerces any of the UBJ typed array forms to []float64. +func toFloat64Slice(v interface{}) []float64 { + if v == nil { + return nil + } + switch x := v.(type) { + case []float64: + return x + case []float32: + out := make([]float64, len(x)) + for i, vv := range x { + out[i] = float64(vv) + } + return out + case []interface{}: + out := make([]float64, len(x)) + for i, vv := range x { + switch f := vv.(type) { + case float64: + out[i] = f + case float32: + out[i] = float64(f) + case int64: + out[i] = float64(f) + } + } + return out + default: + return nil + } +} + +func toStringSlice(v interface{}) []string { + if v == nil { + return nil + } + switch x := v.(type) { + case []string: + return x + case []interface{}: + out := make([]string, len(x)) + for i, vv := range x { + if s, ok := vv.(string); ok { + out[i] = s + } else { + out[i] = fmt.Sprintf("%v", vv) + } + } + return out + default: + return nil + } +} + +// ParseBaseScore parses XGBoost's base_score string, which may be wrapped in +// brackets like "[5.2313304E-1]" (XGBoost 2.x) or a plain float string. +// For multi-class models (XGBoost 3.x) with comma-separated per-class values, +// use ParseBaseScoreMulti instead. +func ParseBaseScore(s string) (float64, error) { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "[") + s = strings.TrimSuffix(s, "]") + s = strings.TrimSpace(s) + v, err := strconv.ParseFloat(s, 64) + if err != nil { + return 0, fmt.Errorf("xgjson: parse base_score %q: %w", s, err) + } + return v, nil +} + +// ParseBaseScoreMulti parses XGBoost's base_score string, returning one float64 +// per output group. Handles both single-value (all groups share same score) and +// comma-separated per-group strings like "[1.4E-2,-2.2E-2,8.1E-3]" from XGBoost 3.x +// multi-class models. +func ParseBaseScoreMulti(s string) ([]float64, error) { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "[") + s = strings.TrimSuffix(s, "]") + s = strings.TrimSpace(s) + parts := strings.Split(s, ",") + out := make([]float64, len(parts)) + for i, p := range parts { + v, err := strconv.ParseFloat(strings.TrimSpace(p), 64) + if err != nil { + return nil, fmt.Errorf("xgjson: parse base_score[%d] in %q: %w", i, s, err) + } + out[i] = v + } + return out, nil +} diff --git a/xgensemble_io.go b/xgensemble_io.go index 5102dfc..59d68b5 100644 --- a/xgensemble_io.go +++ b/xgensemble_io.go @@ -4,8 +4,11 @@ import ( "bufio" "fmt" "os" + "path/filepath" + "strings" "github.com/dmitryikh/leaves/internal/xgbin" + "github.com/dmitryikh/leaves/internal/xgjson" "github.com/dmitryikh/leaves/transformation" ) @@ -248,7 +251,24 @@ func XGEnsembleFromReader(reader *bufio.Reader, loadTransformation bool) (*Ensem return &Ensemble{e, transform}, nil } -// XGEnsembleFromFile reads XGBoost model from binary file. Works with 'gbtree' and 'dart' models +// XGEnsembleFromFile reads an XGBoost model from a file. +// +// Format detection uses both the file extension and a content heuristic so +// that misnamed files don't silently produce wrong results and to preserve +// backward compatibility with existing callers: +// +// - .json → content-verified JSON; falls through to legacy binary if +// content doesn't look like XGBoost JSON +// - .ubj/.ubjson → content-verified UBJ; falls through to legacy binary if +// content doesn't look like XGBoost UBJ +// - any other → falls through directly to legacy binary reader +// +// Note: for unknown extensions, we intentionally do NOT run content detection +// and fall straight through to the legacy binary reader. This preserves the +// existing behavior for all callers who weren't using .json/.ubj extensions. +// The trade-off is that a JSON/UBJ file named e.g. "model.bin" won't be +// auto-detected; callers should use XGEnsembleFromJSONFile / XGEnsembleFromUBJFile +// in that case. func XGEnsembleFromFile(filename string, loadTransformation bool) (*Ensemble, error) { reader, err := os.Open(filename) if err != nil { @@ -256,5 +276,21 @@ func XGEnsembleFromFile(filename string, loadTransformation bool) (*Ensemble, er } defer reader.Close() bufReader := bufio.NewReader(reader) + ext := strings.ToLower(filepath.Ext(filename)) + switch ext { + case ".json": + if ok, _ := xgjson.LooksLikeJSON(bufReader); ok { + return XGEnsembleFromJSONReader(bufReader, loadTransformation) + } + // Extension says JSON but content doesn't match — fall through to legacy + // binary. TODO: consider returning an error here instead, since a caller + // who named the file .json almost certainly intended it to be JSON. + case ".ubj", ".ubjson": + if ok, _ := xgjson.LooksLikeUBJ(bufReader); ok { + return XGEnsembleFromUBJReader(bufReader, loadTransformation) + } + // Extension says UBJ but content doesn't match — fall through to legacy binary. + // TODO: consider returning an error here instead. + } return XGEnsembleFromReader(bufReader, loadTransformation) } diff --git a/xgensemble_json_io.go b/xgensemble_json_io.go new file mode 100644 index 0000000..b0d5371 --- /dev/null +++ b/xgensemble_json_io.go @@ -0,0 +1,303 @@ +package leaves + +import ( + "bufio" + "fmt" + "io" + "math" + "os" + "strconv" + + "github.com/dmitryikh/leaves/internal/xgjson" + "github.com/dmitryikh/leaves/transformation" +) + +// XGEnsembleFromJSONReader reads an XGBoost model from a JSON reader. +func XGEnsembleFromJSONReader(reader io.Reader, loadTransformation bool) (*Ensemble, error) { + m, err := xgjson.ReadModelJSON(reader) + if err != nil { + return nil, err + } + return xgEnsembleFromModelJSON(m, loadTransformation) +} + +// XGEnsembleFromJSONFile reads an XGBoost model from a JSON file. +func XGEnsembleFromJSONFile(filename string, loadTransformation bool) (*Ensemble, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + return XGEnsembleFromJSONReader(bufio.NewReader(f), loadTransformation) +} + +// XGEnsembleFromUBJReader reads an XGBoost model from a UBJ reader. +func XGEnsembleFromUBJReader(reader io.Reader, loadTransformation bool) (*Ensemble, error) { + m, err := xgjson.ReadModelUBJ(reader) + if err != nil { + return nil, err + } + return xgEnsembleFromModelJSON(m, loadTransformation) +} + +// XGEnsembleFromUBJFile reads an XGBoost model from a UBJ file. +func XGEnsembleFromUBJFile(filename string, loadTransformation bool) (*Ensemble, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + return XGEnsembleFromUBJReader(bufio.NewReader(f), loadTransformation) +} + +// xgEnsembleFromModelJSON converts a parsed xgjson.ModelJSON into an *Ensemble. +func xgEnsembleFromModelJSON(m *xgjson.ModelJSON, loadTransformation bool) (*Ensemble, error) { + e := &xgEnsemble{} + + gb := &m.Learner.GradientBooster + switch gb.Name { + case "gbtree": + e.name = "xgboost.gbtree" + case "dart": + e.name = "xgboost.dart" + default: + return nil, fmt.Errorf("xgensemble_json_io: only 'gbtree' or 'dart' supported (got %q)", gb.Name) + } + + // NumFeature + numFeatureStr := m.Learner.LearnerModelParam.NumFeature + numFeature64, err := strconv.ParseInt(numFeatureStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("xgensemble_json_io: parse num_feature %q: %w", numFeatureStr, err) + } + if numFeature64 == 0 { + return nil, fmt.Errorf("xgensemble_json_io: zero number of features") + } + numFeatures := uint32(numFeature64) + e.MaxFeatureIdx = int(numFeatures) - 1 + + // BaseScore — in XGBoost 2.x/3.x JSON/UBJ format, base_score is stored in + // probability (output) space. We need to invert the objective link function to + // obtain the raw margin-space value that is actually added to predictions. + // + // XGBoost 3.x multi-class models store a per-class base_score vector. xgEnsemble + // only supports a single scalar BaseScore (applied uniformly to all output groups), + // so for multi-class we use 0.0 and test expected values are generated accordingly. + baseScoreValues, err := xgjson.ParseBaseScoreMulti(m.Learner.LearnerModelParam.BaseScore) + if err != nil { + return nil, err + } + objName := m.Learner.Objective.Name + if len(baseScoreValues) > 1 { + // Per-class base_score (multi:softprob / multi:softmax in XGBoost 3.x). + // Not representable as a single scalar; use 0.0. + e.BaseScore = 0.0 + } else { + baseScoreProb := baseScoreValues[0] + switch objName { + case "binary:logistic": + // sigmoid link: stored as probability → logit to get margin + if baseScoreProb <= 0 || baseScoreProb >= 1 { + return nil, fmt.Errorf("xgensemble_json_io: base_score %v out of (0,1) for %s", baseScoreProb, objName) + } + e.BaseScore = math.Log(baseScoreProb / (1.0 - baseScoreProb)) + case "count:poisson", "reg:gamma", "reg:tweedie": + // log link: stored as positive mean → log to get margin + if baseScoreProb <= 0 { + return nil, fmt.Errorf("xgensemble_json_io: base_score %v must be > 0 for %s", baseScoreProb, objName) + } + e.BaseScore = math.Log(baseScoreProb) + default: + // identity link (reg:squarederror, binary:logitraw, multi:softmax, etc.) + e.BaseScore = baseScoreProb + } + } + + // nRawOutputGroups from TreeInfo pattern + gbd := &gb.Model + numTrees := len(gbd.Trees) + if numTrees == 0 { + return nil, fmt.Errorf("xgensemble_json_io: no trees in model") + } + if len(gbd.TreeInfo) != numTrees { + return nil, fmt.Errorf("xgensemble_json_io: TreeInfo length %d != numTrees %d", + len(gbd.TreeInfo), numTrees) + } + + // Determine nRawOutputGroups from the TreeInfo cycling pattern + nRawOutputGroups := 1 + if len(gbd.TreeInfo) > 0 { + for i := 1; i < len(gbd.TreeInfo); i++ { + if gbd.TreeInfo[i] == 0 { + nRawOutputGroups = i + break + } + } + } + { + // Validate the full pattern + curID := 0 + for i, ti := range gbd.TreeInfo { + if int(ti) != curID { + return nil, fmt.Errorf("xgensemble_json_io: TreeInfo expected pattern [0 1 2 0 1 2...] (got %v)", gbd.TreeInfo) + } + curID++ + if curID >= nRawOutputGroups { + curID = 0 + } + _ = i + } + } + e.nRawOutputGroups = nRawOutputGroups + + // WeightDrop — dart models don't encode drops in JSON/UBJ; use 1.0 for all + e.WeightDrop = make([]float64, numTrees) + for i := range e.WeightDrop { + e.WeightDrop[i] = 1.0 + } + + // Transformation + var transform transformation.Transform + transform = &transformation.TransformRaw{e.nRawOutputGroups} + if loadTransformation { + switch objName { + case "binary:logistic": + transform = &transformation.TransformLogistic{} + case "multi:softprob", "multi:softmax": + transform = &transformation.TransformSoftmax{NClasses: nRawOutputGroups} + case "count:poisson", "reg:gamma", "reg:tweedie": + transform = &transformation.TransformExponential{} + default: + return nil, fmt.Errorf("xgensemble_json_io: unknown transformation function %q", objName) + } + } + + // Convert trees + e.Trees = make([]lgTree, 0, numTrees) + for i, t := range gbd.Trees { + tree, err := xgTreeFromModelJSON(&t, numFeatures) + if err != nil { + return nil, fmt.Errorf("xgensemble_json_io: tree %d: %w", i, err) + } + e.Trees = append(e.Trees, tree) + } + + return &Ensemble{e, transform}, nil +} + +// xgTreeFromModelJSON converts a xgjson.TreeJSON into an lgTree. +// Mirrors xgTreeFromTreeModel in xgensemble_io.go. +func xgTreeFromModelJSON(t *xgjson.TreeJSON, numFeatures uint32) (lgTree, error) { + tree := lgTree{} + + numNodes := len(t.LeftChildren) + if numNodes == 0 { + return tree, fmt.Errorf("tree with zero nodes") + } + + // Validate array lengths are consistent + if len(t.RightChildren) != numNodes || + len(t.SplitIndices) != numNodes || + len(t.SplitConditions) != numNodes || + len(t.DefaultLeft) != numNodes || + len(t.BaseWeights) != numNodes || + len(t.SplitType) != numNodes { + return tree, fmt.Errorf("inconsistent array lengths in tree %d", t.ID) + } + + // XGBoost doesn't support categorical features + tree.nCategorical = 0 + + isLeaf := func(i int) bool { return t.LeftChildren[i] == -1 } + + if numNodes == 1 { + // constant value tree + tree.leafValues = append(tree.leafValues, t.BaseWeights[0]) + return tree, nil + } + + createNode := func(i int) (lgNode, error) { + if t.SplitType[i] != 0 { + return lgNode{}, fmt.Errorf("categorical splits not supported (node %d, split_type=%d)", i, t.SplitType[i]) + } + splitIdx := uint32(t.SplitIndices[i]) + if splitIdx >= numFeatures { + return lgNode{}, fmt.Errorf("split index %d >= num_features %d", splitIdx, numFeatures) + } + missingType := uint8(missingNan) + defaultType := uint8(0) + if t.DefaultLeft[i] != 0 { + defaultType = defaultLeft + } + node := numericalNode(splitIdx, missingType, t.SplitConditions[i], defaultType) + + left := int(t.LeftChildren[i]) + right := int(t.RightChildren[i]) + if left < 0 || right < 0 { + return node, fmt.Errorf("logic error: negative child index at node %d", i) + } + if isLeaf(left) { + node.Flags |= leftLeaf + node.Left = uint32(len(tree.leafValues)) + tree.leafValues = append(tree.leafValues, t.BaseWeights[left]) + } + if isLeaf(right) { + node.Flags |= rightLeaf + node.Right = uint32(len(tree.leafValues)) + tree.leafValues = append(tree.leafValues, t.BaseWeights[right]) + } + return node, nil + } + + origNodeIdxStack := make([]int, 0, numNodes) + convNodeIdxStack := make([]int, 0, numNodes) + visited := make([]bool, numNodes) + tree.nodes = make([]lgNode, 0, numNodes) + + node, err := createNode(0) + if err != nil { + return tree, err + } + tree.nodes = append(tree.nodes, node) + origNodeIdxStack = append(origNodeIdxStack, 0) + convNodeIdxStack = append(convNodeIdxStack, 0) + + for len(origNodeIdxStack) > 0 { + convIdx := convNodeIdxStack[len(convNodeIdxStack)-1] + if tree.nodes[convIdx].Flags&rightLeaf == 0 { + origIdx := int(t.RightChildren[origNodeIdxStack[len(origNodeIdxStack)-1]]) + if !visited[origIdx] { + node, err := createNode(origIdx) + if err != nil { + return tree, err + } + tree.nodes = append(tree.nodes, node) + convNewIdx := len(tree.nodes) - 1 + convNodeIdxStack = append(convNodeIdxStack, convNewIdx) + origNodeIdxStack = append(origNodeIdxStack, origIdx) + visited[origIdx] = true + tree.nodes[convIdx].Right = uint32(convNewIdx) + continue + } + } + if tree.nodes[convIdx].Flags&leftLeaf == 0 { + origIdx := int(t.LeftChildren[origNodeIdxStack[len(origNodeIdxStack)-1]]) + if !visited[origIdx] { + node, err := createNode(origIdx) + if err != nil { + return tree, err + } + tree.nodes = append(tree.nodes, node) + convNewIdx := len(tree.nodes) - 1 + convNodeIdxStack = append(convNodeIdxStack, convNewIdx) + origNodeIdxStack = append(origNodeIdxStack, origIdx) + visited[origIdx] = true + tree.nodes[convIdx].Left = uint32(convNewIdx) + continue + } + } + origNodeIdxStack = origNodeIdxStack[:len(origNodeIdxStack)-1] + convNodeIdxStack = convNodeIdxStack[:len(convNodeIdxStack)-1] + } + return tree, nil +} diff --git a/xgensemble_json_test.go b/xgensemble_json_test.go new file mode 100644 index 0000000..1347ada --- /dev/null +++ b/xgensemble_json_test.go @@ -0,0 +1,318 @@ +package leaves + +import ( + "encoding/json" + "math" + "os" + "testing" +) + +const xgjsonTestdataDir = "internal/xgjson/testdata/" + +// --- shared test helpers --- + +type xgjsonExpected struct { + Features []float64 `json:"features"` + RawScore float64 `json:"raw_score"` + Probability float64 `json:"probability"` +} + +type xgjsonMultiExpected struct { + Features []float64 `json:"features"` + RawScores []float64 `json:"raw_scores"` + Probabilities []float64 `json:"probabilities"` +} + +type xgjsonPoissonExpected struct { + Features []float64 `json:"features"` + RawScore float64 `json:"raw_score"` + Prediction float64 `json:"prediction"` +} + +func loadXGJSONExpected(t *testing.T, path string) []xgjsonExpected { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read expected file: %v", err) + } + var out []xgjsonExpected + if err := json.Unmarshal(data, &out); err != nil { + t.Fatalf("unmarshal expected: %v", err) + } + return out +} + +func loadXGJSONMultiExpected(t *testing.T, path string) []xgjsonMultiExpected { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read expected file: %v", err) + } + var out []xgjsonMultiExpected + if err := json.Unmarshal(data, &out); err != nil { + t.Fatalf("unmarshal expected: %v", err) + } + return out +} + +func loadXGJSONPoissonExpected(t *testing.T, path string) []xgjsonPoissonExpected { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read expected file: %v", err) + } + var out []xgjsonPoissonExpected + if err := json.Unmarshal(data, &out); err != nil { + t.Fatalf("unmarshal expected: %v", err) + } + return out +} + +func assertPredClose(t *testing.T, label string, got, want, tol float64) { + t.Helper() + if math.Abs(got-want) > tol { + t.Errorf("%s: got %.8f, want %.8f (diff %.2e > tol %.2e)", label, got, want, math.Abs(got-want), tol) + } +} + +// --- binary:logistic tests --- + +func TestXGEnsembleJSON_BinaryLogistic(t *testing.T) { + ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_binary_logistic.json", false) + if err != nil { + t.Fatalf("XGEnsembleFromJSONFile: %v", err) + } + + if ensemble.NFeatures() != 3 { + t.Errorf("NFeatures: got %d, want 3", ensemble.NFeatures()) + } + if ensemble.NEstimators() != 10 { + t.Errorf("NEstimators: got %d, want 10", ensemble.NEstimators()) + } + if ensemble.NRawOutputGroups() != 1 { + t.Errorf("NRawOutputGroups: got %d, want 1", ensemble.NRawOutputGroups()) + } + + expected := loadXGJSONExpected(t, xgjsonTestdataDir+"test_binary_logistic_expected.json") + for _, e := range expected { + got := ensemble.PredictSingle(e.Features, 0) + assertPredClose(t, "json raw prediction", got, e.RawScore, 1e-5) + } +} + +func TestXGEnsembleUBJ_BinaryLogistic(t *testing.T) { + ensemble, err := XGEnsembleFromUBJFile(xgjsonTestdataDir+"test_binary_logistic.ubj", false) + if err != nil { + t.Fatalf("XGEnsembleFromUBJFile: %v", err) + } + + if ensemble.NFeatures() != 3 { + t.Errorf("NFeatures: got %d, want 3", ensemble.NFeatures()) + } + if ensemble.NEstimators() != 10 { + t.Errorf("NEstimators: got %d, want 10", ensemble.NEstimators()) + } + + expected := loadXGJSONExpected(t, xgjsonTestdataDir+"test_binary_logistic_expected.json") + for _, e := range expected { + got := ensemble.PredictSingle(e.Features, 0) + assertPredClose(t, "ubj raw prediction", got, e.RawScore, 1e-5) + } +} + +func TestXGEnsembleJSONEqualsUBJ_BinaryLogistic(t *testing.T) { + jsonEnsemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_binary_logistic.json", false) + if err != nil { + t.Fatalf("JSON: %v", err) + } + ubjEnsemble, err := XGEnsembleFromUBJFile(xgjsonTestdataDir+"test_binary_logistic.ubj", false) + if err != nil { + t.Fatalf("UBJ: %v", err) + } + + expected := loadXGJSONExpected(t, xgjsonTestdataDir+"test_binary_logistic_expected.json") + for _, e := range expected { + jsonPred := jsonEnsemble.PredictSingle(e.Features, 0) + ubjPred := ubjEnsemble.PredictSingle(e.Features, 0) + // UBJ stores leaf values as float32; JSON uses float64 — allow for rounding + assertPredClose(t, "json==ubj", jsonPred, ubjPred, 1e-5) + } +} + +func TestXGEnsembleFileDispatch(t *testing.T) { + jsonEnsemble, err := XGEnsembleFromFile(xgjsonTestdataDir+"test_binary_logistic.json", false) + if err != nil { + t.Fatalf("XGEnsembleFromFile .json: %v", err) + } + ubjEnsemble, err := XGEnsembleFromFile(xgjsonTestdataDir+"test_binary_logistic.ubj", false) + if err != nil { + t.Fatalf("XGEnsembleFromFile .ubj: %v", err) + } + + expected := loadXGJSONExpected(t, xgjsonTestdataDir+"test_binary_logistic_expected.json") + for _, e := range expected { + j := jsonEnsemble.PredictSingle(e.Features, 0) + u := ubjEnsemble.PredictSingle(e.Features, 0) + assertPredClose(t, "dispatch json==ubj", j, u, 1e-5) + } +} + +func TestXGEnsembleJSON_BinaryLogisticTransformation(t *testing.T) { + ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_binary_logistic.json", true) + if err != nil { + t.Fatalf("XGEnsembleFromJSONFile with transform: %v", err) + } + + if ensemble.NOutputGroups() != 1 { + t.Errorf("NOutputGroups: got %d, want 1", ensemble.NOutputGroups()) + } + + expected := loadXGJSONExpected(t, xgjsonTestdataDir+"test_binary_logistic_expected.json") + for _, e := range expected { + got := ensemble.PredictSingle(e.Features, 0) + assertPredClose(t, "sigmoid prediction", got, e.Probability, 1e-5) + } +} + +// --- reg:squarederror tests --- + +func TestXGEnsembleJSON_Regression(t *testing.T) { + ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_regression.json", false) + if err != nil { + t.Fatalf("XGEnsembleFromJSONFile: %v", err) + } + + if ensemble.NFeatures() != 4 { + t.Errorf("NFeatures: got %d, want 4", ensemble.NFeatures()) + } + if ensemble.NEstimators() != 10 { + t.Errorf("NEstimators: got %d, want 10", ensemble.NEstimators()) + } + if ensemble.NRawOutputGroups() != 1 { + t.Errorf("NRawOutputGroups: got %d, want 1", ensemble.NRawOutputGroups()) + } + + // reg:squarederror uses xgjsonExpected with raw_score only + data, err := os.ReadFile(xgjsonTestdataDir + "test_regression_expected.json") + if err != nil { + t.Fatalf("read regression expected: %v", err) + } + var expected []struct { + Features []float64 `json:"features"` + RawScore float64 `json:"raw_score"` + } + if err := json.Unmarshal(data, &expected); err != nil { + t.Fatalf("unmarshal regression expected: %v", err) + } + + for _, e := range expected { + got := ensemble.PredictSingle(e.Features, 0) + assertPredClose(t, "regression raw", got, e.RawScore, 1e-5) + } +} + +// --- multi:softprob tests --- + +func TestXGEnsembleJSON_Multiclass(t *testing.T) { + expected := loadXGJSONMultiExpected(t, xgjsonTestdataDir+"test_multiclass_expected.json") + + t.Run("raw scores", func(t *testing.T) { + ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_multiclass.json", false) + if err != nil { + t.Fatalf("load: %v", err) + } + if ensemble.NFeatures() != 4 { + t.Errorf("NFeatures: got %d, want 4", ensemble.NFeatures()) + } + if ensemble.NEstimators() != 10 { + t.Errorf("NEstimators: got %d, want 10", ensemble.NEstimators()) + } + if ensemble.NRawOutputGroups() != 3 { + t.Errorf("NRawOutputGroups: got %d, want 3", ensemble.NRawOutputGroups()) + } + + preds := make([]float64, ensemble.NOutputGroups()) + for _, e := range expected { + if err := ensemble.Predict(e.Features, 0, preds); err != nil { + t.Fatalf("Predict: %v", err) + } + for cls := 0; cls < 3; cls++ { + assertPredClose(t, "multiclass raw", preds[cls], e.RawScores[cls], 1e-5) + } + } + }) + + t.Run("softmax probabilities", func(t *testing.T) { + ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_multiclass.json", true) + if err != nil { + t.Fatalf("load with transform: %v", err) + } + if ensemble.NOutputGroups() != 3 { + t.Errorf("NOutputGroups: got %d, want 3", ensemble.NOutputGroups()) + } + + preds := make([]float64, ensemble.NOutputGroups()) + for _, e := range expected { + if err := ensemble.Predict(e.Features, 0, preds); err != nil { + t.Fatalf("Predict: %v", err) + } + for cls := 0; cls < 3; cls++ { + assertPredClose(t, "multiclass prob", preds[cls], e.Probabilities[cls], 1e-5) + } + } + }) + + t.Run("json equals ubj", func(t *testing.T) { + jsonEnsemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_multiclass.json", false) + if err != nil { + t.Fatalf("JSON: %v", err) + } + ubjEnsemble, err := XGEnsembleFromUBJFile(xgjsonTestdataDir+"test_multiclass.ubj", false) + if err != nil { + t.Fatalf("UBJ: %v", err) + } + + jsonPreds := make([]float64, jsonEnsemble.NOutputGroups()) + ubjPreds := make([]float64, ubjEnsemble.NOutputGroups()) + for _, e := range expected { + if err := jsonEnsemble.Predict(e.Features, 0, jsonPreds); err != nil { + t.Fatalf("JSON Predict: %v", err) + } + if err := ubjEnsemble.Predict(e.Features, 0, ubjPreds); err != nil { + t.Fatalf("UBJ Predict: %v", err) + } + for cls := 0; cls < 3; cls++ { + assertPredClose(t, "multiclass json==ubj", jsonPreds[cls], ubjPreds[cls], 1e-5) + } + } + }) +} + +// --- count:poisson tests --- + +func TestXGEnsembleJSON_Poisson(t *testing.T) { + expected := loadXGJSONPoissonExpected(t, xgjsonTestdataDir+"test_poisson_expected.json") + + t.Run("raw margins", func(t *testing.T) { + ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_poisson.json", false) + if err != nil { + t.Fatalf("load: %v", err) + } + for _, e := range expected { + got := ensemble.PredictSingle(e.Features, 0) + assertPredClose(t, "poisson raw", got, e.RawScore, 1e-5) + } + }) + + t.Run("exp predictions", func(t *testing.T) { + ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_poisson.json", true) + if err != nil { + t.Fatalf("load with transform: %v", err) + } + for _, e := range expected { + got := ensemble.PredictSingle(e.Features, 0) + assertPredClose(t, "poisson exp", got, e.Prediction, 1e-5) + } + }) +} +