// maxPrealloc caps how many bytes or elements are reserved purely on the
// strength of a length prefix. Anything larger is grown incrementally as data
// actually arrives, so a corrupt or hostile prefix on a short input yields an
// error instead of a giant allocation.
const maxPrealloc = 1 << 16

// DecodeValue reads one UBJSON value from r and returns it as a native Go value:
//
//	map[string]interface{}                  for objects
//	[]interface{}                           for untyped/mixed arrays
//	[]int32, []int64, []float32, []float64  for optimized typed arrays
//	string, bool, int64, float64, nil       for scalars
//
// Leading no-op markers ('N') are skipped. Malformed input (unknown markers,
// negative lengths, truncated payloads) is reported as an error.
func DecodeValue(r io.Reader) (interface{}, error) {
	marker, err := readByte(r)
	if err != nil {
		return nil, err
	}
	return decodeMarker(r, marker)
}

// decodeMarker decodes the value whose type marker has already been consumed.
// All integer widths widen to int64 and float32 widens to float64; 'C' becomes
// a one-byte string and 'H' (high-precision number) is returned as its string
// form.
func decodeMarker(r io.Reader, marker byte) (interface{}, error) {
	for marker == 'N' { // no-op: skip and read the next marker
		var err error
		if marker, err = readByte(r); err != nil {
			return nil, err
		}
	}

	switch marker {
	case 'Z': // null
		return nil, nil
	case 'T': // true
		return true, nil
	case 'F': // false
		return false, nil
	case 'i', 'U', 'I', 'l', 'L': // int8, uint8, int16, int32, int64
		v, err := readInt(r, marker)
		if err != nil {
			return nil, err
		}
		return v, nil
	case 'd': // float32 big-endian
		var raw [4]byte
		if _, err := io.ReadFull(r, raw[:]); err != nil {
			return nil, err
		}
		return float64(math.Float32frombits(binary.BigEndian.Uint32(raw[:]))), nil
	case 'D': // float64 big-endian
		var raw [8]byte
		if _, err := io.ReadFull(r, raw[:]); err != nil {
			return nil, err
		}
		return math.Float64frombits(binary.BigEndian.Uint64(raw[:])), nil
	case 'C': // char: single byte as string
		b, err := readByte(r)
		if err != nil {
			return nil, err
		}
		return string([]byte{b}), nil
	case 'S': // string: length + bytes
		s, err := readString(r, "string")
		if err != nil {
			return nil, err
		}
		return s, nil
	case 'H': // high-precision number: kept as its decimal string
		s, err := readString(r, "high-precision")
		if err != nil {
			return nil, err
		}
		return s, nil
	case '[':
		return decodeArray(r)
	case '{':
		// Return an untyped nil on failure; passing the nil map straight
		// through would wrap it in a non-nil interface value.
		obj, err := decodeObject(r)
		if err != nil {
			return nil, err
		}
		return obj, nil
	default:
		return nil, fmt.Errorf("ubjdecode: unknown marker byte 0x%02X ('%c')", marker, marker)
	}
}

// readByte reads a single byte from r.
func readByte(r io.Reader) (byte, error) {
	var buf [1]byte
	_, err := io.ReadFull(r, buf[:])
	return buf[0], err
}

// intWidth returns the payload size in bytes of the integer type identified by
// marker, or 0 when marker is not an integer type.
func intWidth(marker byte) int {
	switch marker {
	case 'i', 'U':
		return 1
	case 'I':
		return 2
	case 'l':
		return 4
	case 'L':
		return 8
	}
	return 0
}

// readInt reads the big-endian payload of the integer type identified by
// marker ('i', 'U', 'I', 'l' or 'L') and widens it to int64.
func readInt(r io.Reader, marker byte) (int64, error) {
	var raw [8]byte
	buf := raw[:intWidth(marker)]
	if _, err := io.ReadFull(r, buf); err != nil {
		return 0, err
	}
	switch marker {
	case 'i':
		return int64(int8(buf[0])), nil
	case 'U':
		return int64(buf[0]), nil
	case 'I':
		return int64(int16(binary.BigEndian.Uint16(buf))), nil
	case 'l':
		return int64(int32(binary.BigEndian.Uint32(buf))), nil
	case 'L':
		return int64(binary.BigEndian.Uint64(buf)), nil
	}
	return 0, fmt.Errorf("ubjdecode: marker 0x%02X ('%c') is not an integer type", marker, marker)
}

// lengthFromMarker reads a length whose integer type marker has already been
// consumed. what names the construct ("length", "key length") in errors.
// Negative values are rejected because a length feeds directly into an
// allocation.
func lengthFromMarker(r io.Reader, marker byte, what string) (int64, error) {
	if intWidth(marker) == 0 {
		return 0, fmt.Errorf("ubjdecode: unexpected %s marker 0x%02X ('%c')", what, marker, marker)
	}
	n, err := readInt(r, marker)
	if err != nil {
		return 0, err
	}
	if n < 0 {
		return 0, fmt.Errorf("ubjdecode: negative %s %d", what, n)
	}
	return n, nil
}

// readLength reads an integer length value (used for string lengths and
// array/object counts). The next byte is a type marker ('i','U','I','l','L')
// giving the integer width. The returned length is never negative.
func readLength(r io.Reader) (int64, error) {
	marker, err := readByte(r)
	if err != nil {
		return 0, err
	}
	return lengthFromMarker(r, marker, "length")
}

// readBytes reads exactly n bytes from r. Payloads larger than maxPrealloc are
// read in chunks so memory use tracks the data actually present in r rather
// than the declared length.
func readBytes(r io.Reader, n int64) ([]byte, error) {
	if n < 0 {
		return nil, fmt.Errorf("ubjdecode: negative length %d", n)
	}
	if n <= maxPrealloc {
		buf := make([]byte, n)
		if _, err := io.ReadFull(r, buf); err != nil {
			return nil, err
		}
		return buf, nil
	}

	buf := make([]byte, 0, maxPrealloc)
	for remaining := n; remaining > 0; {
		chunk := remaining
		if chunk > maxPrealloc {
			chunk = maxPrealloc
		}
		filled := len(buf)
		buf = append(buf, make([]byte, chunk)...)
		if _, err := io.ReadFull(r, buf[filled:]); err != nil {
			if err == io.EOF && filled > 0 {
				// Some of the payload was read, so this is a truncation.
				err = io.ErrUnexpectedEOF
			}
			return nil, err
		}
		remaining -= chunk
	}
	return buf, nil
}

// readString reads a length-prefixed byte string (the payload of 'S' and 'H').
// what names the construct in the length error.
func readString(r io.Reader, what string) (string, error) {
	n, err := readLength(r)
	if err != nil {
		return "", fmt.Errorf("%s length: %w", what, err)
	}
	buf, err := readBytes(r, n)
	if err != nil {
		return "", err
	}
	return string(buf), nil
}

// capHint converts a declared, non-negative element count into a safe initial
// slice capacity.
func capHint(n int64) int {
	if n > maxPrealloc {
		return maxPrealloc
	}
	return int(n)
}

// readTypeAndCount parses the rest of an optimized container header after '$':
// the element type marker, the mandatory '#', and the element count.
// container ("array" or "object") is used in error messages.
func readTypeAndCount(r io.Reader, container string) (byte, int64, error) {
	typeMarker, err := readByte(r)
	if err != nil {
		return 0, 0, err
	}
	sep, err := readByte(r)
	if err != nil {
		return 0, 0, err
	}
	if sep != '#' {
		return 0, 0, fmt.Errorf("ubjdecode: typed %s of '%c' must declare a count: expected '#', got 0x%02X",
			container, typeMarker, sep)
	}
	count, err := readLength(r)
	if err != nil {
		return 0, 0, fmt.Errorf("typed %s count: %w", container, err)
	}
	return typeMarker, count, nil
}

// decodeArray decodes a UBJSON array. The '[' marker has already been consumed.
// Supported container forms:
//  1. [$][type][#][count]: typed + counted, returns a typed Go slice
//  2. [#][count]: counted only, returns []interface{}
//  3. elements followed by ']': returns []interface{}
//
// A type without a count is rejected. Without a count the element bytes carry
// no markers, so a data byte equal to ']' (0x5D) cannot be told apart from the
// terminator.
func decodeArray(r io.Reader) (interface{}, error) {
	next, err := readByte(r)
	if err != nil {
		return nil, err
	}

	switch next {
	case '$':
		typeMarker, count, err := readTypeAndCount(r, "array")
		if err != nil {
			return nil, err
		}
		n := int(count)
		if int64(n) != count {
			return nil, fmt.Errorf("ubjdecode: typed array count %d overflows int", count)
		}
		return readTypedArray(r, typeMarker, n)

	case '#':
		count, err := readLength(r)
		if err != nil {
			return nil, fmt.Errorf("counted array length: %w", err)
		}
		result := make([]interface{}, 0, capHint(count))
		for i := int64(0); i < count; i++ {
			v, err := DecodeValue(r)
			if err != nil {
				return nil, fmt.Errorf("array element %d: %w", i, err)
			}
			result = append(result, v)
		}
		return result, nil
	}

	// Plain form: markers until ']'. No-ops are skipped here so that one
	// directly before the terminator is not mistaken for a value.
	result := []interface{}{}
	for marker := next; ; {
		switch marker {
		case ']':
			return result, nil
		case 'N':
			// skip
		default:
			v, err := decodeMarker(r, marker)
			if err != nil {
				return nil, fmt.Errorf("array element %d: %w", len(result), err)
			}
			result = append(result, v)
		}
		if marker, err = readByte(r); err != nil {
			return nil, err
		}
	}
}

// readTypedArray reads count elements of a fixed type without per-element
// markers. Numeric types come back as typed slices (int8/uint8/int16/int32 as
// []int32, int64 as []int64, float32 as []float32, float64 as []float64); any
// other element type falls back to []interface{}.
//
// Numeric payloads are read as one block of bytes and converted in memory.
func readTypedArray(r io.Reader, typeMarker byte, count int) (interface{}, error) {
	if count < 0 {
		return nil, fmt.Errorf("ubjdecode: negative typed array count %d", count)
	}

	var width int64
	switch typeMarker {
	case 'i', 'U':
		width = 1
	case 'I':
		width = 2
	case 'l', 'd':
		width = 4
	case 'L', 'D':
		width = 8
	default:
		// Non-numeric element type: decode each element through the
		// regular marker path.
		result := make([]interface{}, 0, capHint(int64(count)))
		for i := 0; i < count; i++ {
			v, err := decodeMarker(r, typeMarker)
			if err != nil {
				return nil, fmt.Errorf("typed array element %d: %w", i, err)
			}
			result = append(result, v)
		}
		return result, nil
	}

	if int64(count) > math.MaxInt64/width {
		return nil, fmt.Errorf("ubjdecode: typed array of %d elements is too large", count)
	}
	raw, err := readBytes(r, int64(count)*width)
	if err != nil {
		return nil, err
	}

	be := binary.BigEndian
	switch typeMarker {
	case 'i':
		out := make([]int32, count)
		for i, b := range raw {
			out[i] = int32(int8(b))
		}
		return out, nil
	case 'U':
		out := make([]int32, count)
		for i, b := range raw {
			out[i] = int32(b)
		}
		return out, nil
	case 'I':
		out := make([]int32, count)
		for i := range out {
			out[i] = int32(int16(be.Uint16(raw[2*i:])))
		}
		return out, nil
	case 'l':
		out := make([]int32, count)
		for i := range out {
			out[i] = int32(be.Uint32(raw[4*i:]))
		}
		return out, nil
	case 'L':
		out := make([]int64, count)
		for i := range out {
			out[i] = int64(be.Uint64(raw[8*i:]))
		}
		return out, nil
	case 'd':
		out := make([]float32, count)
		for i := range out {
			out[i] = math.Float32frombits(be.Uint32(raw[4*i:]))
		}
		return out, nil
	default: // 'D'
		out := make([]float64, count)
		for i := range out {
			out[i] = math.Float64frombits(be.Uint64(raw[8*i:]))
		}
		return out, nil
	}
}

// decodeObject decodes a UBJSON object. The '{' marker has already been
// consumed. Supported container forms:
//  1. [$][type][#][count]: typed + counted; values carry no marker
//  2. [#][count]: counted only
//  3. key/value pairs followed by '}'
//
// Keys are not prefixed with 'S': each is a length followed by the key bytes.
// A repeated key keeps the last value.
func decodeObject(r io.Reader) (map[string]interface{}, error) {
	result := make(map[string]interface{})

	next, err := readByte(r)
	if err != nil {
		return nil, err
	}

	switch next {
	case '$':
		typeMarker, count, err := readTypeAndCount(r, "object")
		if err != nil {
			return nil, err
		}
		for i := int64(0); i < count; i++ {
			key, err := readKey(r)
			if err != nil {
				return nil, fmt.Errorf("object key/value %d: %w", i, err)
			}
			val, err := decodeMarker(r, typeMarker)
			if err != nil {
				return nil, fmt.Errorf("object key/value %d: value for key %q: %w", i, key, err)
			}
			result[key] = val
		}
		return result, nil

	case '#':
		count, err := readLength(r)
		if err != nil {
			return nil, fmt.Errorf("object count: %w", err)
		}
		for i := int64(0); i < count; i++ {
			k, v, err := readKeyValue(r)
			if err != nil {
				return nil, fmt.Errorf("object key/value %d: %w", i, err)
			}
			result[k] = v
		}
		return result, nil
	}

	// Plain form: each pair starts with the key's length marker; '}' ends
	// the object and no-ops between pairs are skipped.
	for marker := next; ; {
		switch marker {
		case '}':
			return result, nil
		case 'N':
			// skip
		default:
			k, v, err := readKeyValueFromLenMarker(r, marker)
			if err != nil {
				return nil, fmt.Errorf("object key/value: %w", err)
			}
			result[k] = v
		}
		if marker, err = readByte(r); err != nil {
			return nil, err
		}
	}
}

// readKey reads one object key (length marker, length, bytes), skipping any
// no-op markers in front of it.
func readKey(r io.Reader) (string, error) {
	lenMarker, err := readByte(r)
	for err == nil && lenMarker == 'N' {
		lenMarker, err = readByte(r)
	}
	if err != nil {
		return "", err
	}
	return readKeyFromLenMarker(r, lenMarker)
}

// readKeyFromLenMarker reads an object key whose length type marker has
// already been consumed.
func readKeyFromLenMarker(r io.Reader, lenMarker byte) (string, error) {
	n, err := lengthFromMarker(r, lenMarker, "key length")
	if err != nil {
		return "", err
	}
	buf, err := readBytes(r, n)
	if err != nil {
		return "", fmt.Errorf("reading key: %w", err)
	}
	return string(buf), nil
}

// readKeyValue reads one key-value pair from an object.
// The key length marker has NOT been consumed yet.
func readKeyValue(r io.Reader) (string, interface{}, error) {
	lenMarker, err := readByte(r)
	for err == nil && lenMarker == 'N' {
		lenMarker, err = readByte(r)
	}
	if err != nil {
		return "", nil, err
	}
	return readKeyValueFromLenMarker(r, lenMarker)
}

// readKeyValueFromLenMarker reads one key-value pair where lenMarker is the
// already-consumed type marker of the key length.
func readKeyValueFromLenMarker(r io.Reader, lenMarker byte) (string, interface{}, error) {
	key, err := readKeyFromLenMarker(r, lenMarker)
	if err != nil {
		return "", nil, err
	}
	val, err := DecodeValue(r)
	if err != nil {
		return "", nil, fmt.Errorf("value for key %q: %w", key, err)
	}
	return key, val, nil
}
// LooksLikeJSON peeks at the buffered reader and returns true if the content
// appears to be an XGBoost JSON model file. It does not advance the reader.
//
// Heuristic: the first non-whitespace byte must be '{', and the string
// "\"learner\"" must appear within the first 512 peeked bytes.
//
// When nothing can be peeked the result is false together with the error
// returned by Peek. A short but non-empty input is judged on the bytes
// available and yields a nil error.
func LooksLikeJSON(r *bufio.Reader) (bool, error) {
	buf, err := r.Peek(512)
	if len(buf) == 0 {
		return false, err
	}

	// Skip JSON whitespace; the first significant byte must open an object.
	body := bytes.TrimLeft(buf, " \t\r\n")
	if len(body) == 0 || body[0] != '{' {
		return false, nil
	}

	// Confirm the "learner" key is present in the peeked window.
	return bytes.Contains(buf, []byte(`"learner"`)), nil
}

// LooksLikeUBJ peeks at the buffered reader and returns true if the content
// appears to be an XGBoost UBJ (Universal Binary JSON) model file. It does
// not advance the reader.
//
// Heuristic: the first byte must be '{' (0x7B) and the first key, parsed from
// the peek buffer, must be exactly "learner". Only the plain object form is
// recognised; any other byte after '{' makes the result false.
//
// UBJ key lengths are encoded as <marker><big-endian integer>, where the
// marker is one of: 'i' (int8), 'U' (uint8), 'I' (int16), 'l' (int32),
// 'L' (int64). All widths are handled.
//
// When fewer than two bytes can be peeked the result is false together with
// the error returned by Peek.
func LooksLikeUBJ(r *bufio.Reader) (bool, error) {
	const firstKey = "learner"

	// 32 bytes is more than enough: 1 ('{') + 1 (marker) + 8 (int64) + 7 (key) = 17.
	buf, err := r.Peek(32)
	if len(buf) < 2 {
		return false, err
	}
	if buf[0] != '{' {
		return false, nil
	}

	keyLen, headerLen, ok := ubjReadLength(buf[1:])
	if !ok || keyLen != len(firstKey) {
		return false, nil
	}

	// HasPrefix also covers the case where the peek window ends mid-key.
	return bytes.HasPrefix(buf[1+headerLen:], []byte(firstKey)), nil
}

// ubjReadLength parses a UBJ integer length from b, where b[0] is the type
// marker. It returns (value, bytesConsumed, ok); bytesConsumed includes the
// marker byte.
//
// ok is false when the marker is not an integer type, when b is too short for
// the marker's width, or when the decoded value is not a valid length
// (negative, or too large for int).
func ubjReadLength(b []byte) (value int, consumed int, ok bool) {
	if len(b) < 1 {
		return 0, 0, false
	}

	switch b[0] {
	case 'i': // int8: 1 data byte
		if len(b) < 2 {
			return 0, 0, false
		}
		value, consumed = int(int8(b[1])), 2
	case 'U': // uint8: 1 data byte
		if len(b) < 2 {
			return 0, 0, false
		}
		value, consumed = int(b[1]), 2
	case 'I': // int16 BE: 2 data bytes
		if len(b) < 3 {
			return 0, 0, false
		}
		value, consumed = int(int16(binary.BigEndian.Uint16(b[1:3]))), 3
	case 'l': // int32 BE: 4 data bytes
		if len(b) < 5 {
			return 0, 0, false
		}
		value, consumed = int(int32(binary.BigEndian.Uint32(b[1:5]))), 5
	case 'L': // int64 BE: 8 data bytes
		if len(b) < 9 {
			return 0, 0, false
		}
		raw := binary.BigEndian.Uint64(b[1:9])
		value, consumed = int(raw), 9
		// Reject values that do not survive the round trip through int
		// (high bits lost on 32-bit platforms).
		if value >= 0 && uint64(value) != raw {
			return 0, 0, false
		}
	default:
		return 0, 0, false
	}

	if value < 0 {
		return 0, 0, false
	}
	return value, consumed, true
}
bufio.NewReader(bytes.NewReader([]byte{0x00, 0x01, 0x02, 0x03, 0xFF})) + if ok, _ := xgjson.LooksLikeJSON(r); ok { + t.Error("LooksLikeJSON(random bytes) = true, want false") + } + }) + + t.Run("reader is not consumed", func(t *testing.T) { + f, err := os.Open(testdataDir + "test_binary_logistic.json") + if err != nil { + t.Fatal(err) + } + defer f.Close() + r := bufio.NewReader(f) + xgjson.LooksLikeJSON(r) //nolint:errcheck + // Peek must not consume: a second call should return the same result. + if ok, _ := xgjson.LooksLikeJSON(r); !ok { + t.Error("second LooksLikeJSON call returned false — reader was consumed") + } + }) +} + +func TestLooksLikeUBJ(t *testing.T) { + t.Run("real UBJ model", func(t *testing.T) { + r := openBuf(t, testdataDir+"test_binary_logistic.ubj") + if ok, err := xgjson.LooksLikeUBJ(r); err != nil || !ok { + t.Errorf("LooksLikeUBJ(ubj) = (%v, %v), want (true, nil)", ok, err) + } + }) + + t.Run("JSON file is not UBJ", func(t *testing.T) { + r := openBuf(t, testdataDir+"test_binary_logistic.json") + if ok, _ := xgjson.LooksLikeUBJ(r); ok { + t.Error("LooksLikeUBJ(json) = true, want false") + } + }) + + t.Run("empty reader", func(t *testing.T) { + r := bufio.NewReader(bytes.NewReader(nil)) + if ok, _ := xgjson.LooksLikeUBJ(r); ok { + t.Error("LooksLikeUBJ(empty) = true, want false") + } + }) + + t.Run("random bytes", func(t *testing.T) { + r := bufio.NewReader(bytes.NewReader([]byte{0x00, 0x01, 0x02, 0x03, 0xFF})) + if ok, _ := xgjson.LooksLikeUBJ(r); ok { + t.Error("LooksLikeUBJ(random bytes) = true, want false") + } + }) + + t.Run("reader is not consumed", func(t *testing.T) { + f, err := os.Open(testdataDir + "test_binary_logistic.ubj") + if err != nil { + t.Fatal(err) + } + defer f.Close() + r := bufio.NewReader(f) + xgjson.LooksLikeUBJ(r) //nolint:errcheck + // Peek must not consume: a second call should return the same result. 
+ if ok, _ := xgjson.LooksLikeUBJ(r); !ok { + t.Error("second LooksLikeUBJ call returned false — reader was consumed") + } + }) +} + +func openBuf(t *testing.T, path string) *bufio.Reader { + t.Helper() + f, err := os.Open(path) + if err != nil { + t.Fatalf("open %s: %v", path, err) + } + t.Cleanup(func() { f.Close() }) + return bufio.NewReader(f) +} diff --git a/internal/xgjson/testdata/README.md b/internal/xgjson/testdata/README.md new file mode 100644 index 0000000..105be25 --- /dev/null +++ b/internal/xgjson/testdata/README.md @@ -0,0 +1,22 @@ +# XGBoost Test Model Data + +This directory contains generated XGBoost model files used by the `xgjson` package tests. + +## Files + +- `train.py` — Python script that generates the test models +- `test1.json` — XGBoost model in JSON format (XGBoost 2.x/3.x) +- `test1.ubj` — XGBoost model in Universal Binary JSON format (XGBoost 2.x/3.x) +- `test1_expected.json` — Input feature vectors with expected raw scores and probabilities + +## Regenerating + +Requires [uv](https://docs.astral.sh/uv/). + +```sh +cd internal/xgjson/testdata +go generate +``` + +This trains a small XGBoost binary classifier (10 estimators, max_depth=3) on synthetic data +with 3 features, then writes the three output files above. Commit all three output files. 
diff --git a/internal/xgjson/testdata/generate.go b/internal/xgjson/testdata/generate.go new file mode 100644 index 0000000..73179f5 --- /dev/null +++ b/internal/xgjson/testdata/generate.go @@ -0,0 +1,3 @@ +//go:generate uv run train.py + +package testdata diff --git a/internal/xgjson/testdata/test_binary_logistic.json b/internal/xgjson/testdata/test_binary_logistic.json new file mode 100644 index 0000000..8d20caf --- /dev/null +++ b/internal/xgjson/testdata/test_binary_logistic.json @@ -0,0 +1 @@ +{"learner":{"attributes":{},"feature_names":[],"feature_types":[],"gradient_booster":{"model":{"cats":{"enc":[],"feature_segments":[],"sorted_idx":[]},"gbtree_model_param":{"num_parallel_tree":"1","num_trees":"10"},"iteration_indptr":[0,1,2,3,4,5,6,7,8,9,10],"tree_info":[0,0,0,0,0,0,0,0,0,0],"trees":[{"base_weights":[3.4848927E-8,-1.5204959E0,1.6623036E0,1.1777768E0,-1.6315176E0,7.10179E-1,1.8392156E0,4.33977E-1,7.083949E-2,-5.601467E-1,-1.6758814E-1,3.6204177E-1,-1.5752013E-1,4.071052E-1,5.7988805E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":0,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[5.1048126E2,3.3520996E1,1.4945984E1,8.6061764E-1,2.4676788E1,1.0907537E1,1.0265808E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[2.733439E-1,-1.8622892E0,5.279944E-1,1.3041685E0,-2.5623593E-1,1.3718903E0,7.7473426E-1,4.33977E-1,7.083949E-2,-5.601467E-1,-1.6758814E-1,3.6204177E-1,-1.5752013E-1,4.071052E-1,5.7988805E-1],"split_indices":[2,0,2,1,2,1,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.9996877E2,1.0448368E2,9.5485085E1,3.7494144E0,1.0073427E2,1.574754E1,7.973755E1,2.4996095E0,1.2498047E0,8.19872E1,1.8747072E1,1.1248243E1,4.499297E0,1.549758E1,6.423997E1],"tree_param":{"num_deleted":"0","num_feature":"3","nu
m_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[3.3946685E-3,-1.1543779E0,1.2591227E0,9.5409256E-1,-1.2454361E0,5.2215135E-1,1.4025306E0,3.727428E-1,2.3597144E-2,-4.3348977E-1,-1.21060215E-1,4.5374906E-1,2.7029645E-2,3.0377525E-1,4.4462258E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":1,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[2.7547772E2,2.011644E1,8.874084E0,1.0756211E0,1.5365356E1,6.8806324E0,8.410034E-1,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[2.733439E-1,-1.8622892E0,5.279944E-1,-7.9844646E-2,-2.5623593E-1,-1.8622892E0,7.7473426E-1,3.727428E-1,2.3597144E-2,-4.3348977E-1,-1.21060215E-1,4.5374906E-1,2.7029645E-2,3.0377525E-1,4.4462258E-1],"split_indices":[2,0,2,2,2,0,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.874821E2,9.758766E1,8.989444E1,3.647652E0,9.394001E1,1.5399192E1,7.449525E1,2.428422E0,1.2192299E0,7.53628E1,1.8577206E1,3.9179015E0,1.1481291E1,1.4947725E1,5.954752E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[7.773815E-3,-9.943017E-1,9.9700433E-1,2.7761754E-1,-1.076667E0,5.286407E-1,3.7626263E-1,-3.394221E-1,3.191901E-1,2.0238668E-1,-3.8180283E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0],"id":2,"left_children":[1,3,5,-1,7,9,-1,-1,-1,-1,-1],"loss_changes":[1.6563177E2,1.4121422E1,9.465561E0,0E0,1.0584633E1,9.1505995E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5],"right_children":[2,4,6,-1,8,10,-1,-1,-1,-1,-1],"split_conditions":[2.0596176E-1,-1.8622892E0,7.2267E-1,2.7761754E-1,-4.3772988E-2,2.3320758E0,3.7626263E-1,-3.394221E-1,3.191901E-1,2.0238668E-1,-3.8180283E-1],"split_indices":[2,0,2,0,0,1,0,0,
0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.6508813E2,8.200907E1,8.307906E1,2.9528005E0,7.905627E1,3.0533514E1,5.254555E1,7.754477E1,1.5115029E0,2.8823126E1,1.7103883E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"11","size_leaf_vector":"1"}},{"base_weights":[9.2651285E-3,-1.0554368E0,6.792614E-1,2.3023954E-1,-1.1098394E0,1.0147718E-1,1.0298864E0,-5.1912617E-2,-3.520825E-1,1.7861223E-1,-3.123424E-1,1.9472034E-1,3.3861062E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":3,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[1.0307969E2,6.3869324E0,1.786119E1,0E0,3.0589828E0,2.0146532E1,1.3401375E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-2.5623593E-1,-2.0155988E0,5.279944E-1,2.3023954E-1,-2.6036727E-1,1.3718903E0,7.7473426E-1,-5.1912617E-2,-3.520825E-1,1.7861223E-1,-3.123424E-1,1.9472034E-1,3.3861062E-1],"split_indices":[2,0,2,0,1,1,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.4249661E2,5.4814583E1,8.768203E1,1.2086432E0,5.3605938E1,3.360115E1,5.4080883E1,3.6531324E0,4.9952805E1,2.379479E1,9.806361E0,1.2701178E1,4.1379704E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[6.5199593E-3,-9.5052385E-1,5.596426E-1,2.0603801E-1,-1.0059912E0,7.9649486E-2,9.219858E-1,-3.384241E-1,-1.2668663E-1,1.2908778E-1,-2.4213417E-1,1.5698004E-1,3.0862683E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":4,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[6.568722E1,4.8187065E0,1.3602236E1,0E0,2.6854439E0,1.1204861E1,1.3651428E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"sp
lit_conditions":[-2.5623593E-1,-2.0155988E0,5.594684E-1,2.0603801E-1,-5.7463634E-1,1.3718903E0,7.7473426E-1,-3.384241E-1,-1.2668663E-1,1.2908778E-1,-2.4213417E-1,1.5698004E-1,3.0862683E-1],"split_indices":[2,0,2,0,2,1,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2208248E2,4.445151E1,7.763097E1,1.0835884E0,4.3367924E1,3.391881E1,4.371216E1,3.5094765E1,8.273158E0,2.4676153E1,9.242659E0,1.0479339E1,3.323282E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[3.5492254E-3,-5.9674454E-1,6.862838E-1,-1.0244098E0,-1.5051265E-1,3.5410693E-1,9.457186E-1,-1.983951E-2,-3.1750667E-1,-1.2765785E-1,4.3486968E-1,3.8029623E-1,1.0213314E-2,2.9135826E-1,1.1577517E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":5,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[4.385949E1,1.0713873E1,3.979826E0,8.703499E-1,1.3396497E1,6.8439E0,1.7572403E-2,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[3.263706E-1,-5.7463634E-1,7.7473426E-1,-1.534747E0,-7.9241884E-1,-1.8355771E0,1.0084084E0,-1.983951E-2,-3.1750667E-1,-1.2765785E-1,4.3486968E-1,3.8029623E-1,1.0213314E-2,2.9135826E-1,1.1577517E-1],"split_indices":[2,2,2,0,0,0,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.0501652E2,5.5949192E1,4.9067326E1,2.7907331E1,2.804186E1,2.255417E1,2.6513155E1,1.0556241E0,2.6851707E1,2.455545E1,3.4864097E0,5.08358E0,1.747059E1,2.4656202E1,1.8569524E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[9.643211E-3,-9.6156096E-1,3.2266313E-1,-1.6817693E-2,-1.0025502E0,1.896752E-2,3.062792E-1,-3.163531E-2,-3.1356356E-1,3.1077072E-1,-8.825563E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"
categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0],"id":6,"left_children":[1,3,5,-1,7,9,-1,-1,-1,-1,-1],"loss_changes":[2.8589937E1,7.9467964E-1,1.5057805E1,0E0,7.268925E-1,1.6408224E1,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5],"right_children":[2,4,6,-1,8,10,-1,-1,-1,-1,-1],"split_conditions":[-5.7463634E-1,-1.534747E0,-7.9241884E-1,-1.6817693E-2,-2.3389478E0,-1.8355771E0,3.062792E-1,-3.163531E-2,-3.1356356E-1,3.1077072E-1,-8.825563E-2],"split_indices":[2,0,0,0,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[9.202333E1,2.1924986E1,7.009835E1,1.052031E0,2.0872955E1,4.9567326E1,2.0531021E1,1.1051567E0,1.97678E1,1.1126637E1,3.844069E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"11","size_leaf_vector":"1"}},{"base_weights":[1.0875948E-2,-9.042081E-1,2.6562464E-1,-1.4255263E-2,-9.5357305E-1,2.432597E-2,9.072643E-1,-2.707678E-2,-3.0165362E-1,1.6966063E-1,-1.2472395E-1,2.841285E-1,9.571609E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":7,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[1.9578981E1,7.282772E-1,1.0144684E1,0E0,6.824322E-1,1.1845299E1,1.1762619E-1,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-5.7463634E-1,-1.534747E0,-7.9241884E-1,-1.4255263E-2,-2.3389478E0,-1.451989E0,1.0084084E0,-2.707678E-2,-3.0165362E-1,1.6966063E-1,-1.2472395E-1,2.841285E-1,9.571609E-2],"split_indices":[2,0,0,0,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[8.195767E1,1.7292343E1,6.466533E1,1.0490721E0,1.6243273E1,4.7747166E1,1.6918161E1,1.0771346E0,1.5166138E1,2.1285168E1,2.6462E1,1.5273769E1,1.6443915E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[8.420152E-3,-8.866207E-1,2.0744449E-1,-2.3173474E-2,-2.8409058E-1,3.57335
18E-1,-6.553415E-1,3.0889487E-1,2.7870974E-2,2.424394E-1,-3.562318E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0],"id":8,"left_children":[1,3,5,-1,-1,7,9,-1,-1,-1,-1],"loss_changes":[1.3613471E1,6.0597706E-1,8.315063E0,0E0,0E0,9.483938E0,8.436519E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,5,5,6,6],"right_children":[2,4,6,-1,-1,8,10,-1,-1,-1,-1],"split_conditions":[-6.423459E-1,-2.3389478E0,1.8273275E0,-2.3173474E-2,-2.8409058E-1,-1.4220484E0,-2.035454E0,3.0889487E-1,2.7870974E-2,2.424394E-1,-3.562318E-1],"split_indices":[2,2,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[7.438975E1,1.2903961E1,6.148579E1,1.053575E0,1.1850386E1,5.2884132E1,8.601657E0,1.4111651E1,3.8772484E1,2.1552174E0,6.4464393E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"11","size_leaf_vector":"1"}},{"base_weights":[3.1257947E-3,-2.370009E-1,7.368995E-1,-4.209105E-5,-8.483909E-1,8.1515753E-1,1.465431E-1,3.7057182E-1,-1.14696495E-1,2.5674397E-1,-3.5496542E-1,7.9826936E-2,2.756763E-1,-1.3072331E-1,1.8127826E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":9,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.2430462E1,7.701972E0,6.841221E-1,1.9037214E1,9.485901E0,6.1796856E-1,1.1308331E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[7.2267E-1,1.398677E0,1.8586534E0,-1.2937937E0,-2.0780978E0,6.0014886E-1,1.2276509E0,3.7057182E-1,-1.14696495E-1,2.5674397E-1,-3.5496542E-1,7.9826936E-2,2.756763E-1,-1.3072331E-1,1.8127826E-1],"split_indices":[2,1,1,0,0,1,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[6.855681E1,5.2163403E1,1.6393408E1,3.8313972E1,1.3849428E1,1.4138422E1,2.254987E0,8.527461E0,2.9786512E1,2.01869
34E0,1.1830735E1,2.8137515E0,1.132467E1,1.0135728E0,1.2414142E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}}]},"name":"gbtree"},"learner_model_param":{"base_score":"[4.9375E-1]","boost_from_average":"1","num_class":"0","num_feature":"3","num_target":"1"},"objective":{"name":"binary:logistic","reg_loss_param":{"scale_pos_weight":"1"}}},"version":[3,2,0]} \ No newline at end of file diff --git a/internal/xgjson/testdata/test_binary_logistic.py b/internal/xgjson/testdata/test_binary_logistic.py new file mode 100644 index 0000000..dad9bd4 --- /dev/null +++ b/internal/xgjson/testdata/test_binary_logistic.py @@ -0,0 +1,74 @@ +# /// script +# dependencies = ["xgboost>=2.0", "numpy", "scikit-learn"] +# /// +""" +Generates binary:logistic XGBoost test model. +Run with: uv run test_binary_logistic.py +Outputs: test_binary_logistic.json, test_binary_logistic.ubj, test_binary_logistic_expected.json +""" + +import json +import os +import numpy as np +from sklearn.datasets import make_classification +import xgboost as xgb + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def main(): + X, y = make_classification( + n_samples=1000, + n_features=3, + n_informative=3, + n_redundant=0, + n_clusters_per_class=1, + random_state=42, + ) + + X_train, X_test = X[:800], X[800:] + y_train = y[:800] + + model = xgb.XGBClassifier( + n_estimators=10, + max_depth=3, + objective="binary:logistic", + random_state=42, + eval_metric="logloss", + ) + model.fit(X_train, y_train) + + json_path = os.path.join(SCRIPT_DIR, "test_binary_logistic.json") + ubj_path = os.path.join(SCRIPT_DIR, "test_binary_logistic.ubj") + model.get_booster().save_model(json_path) + model.get_booster().save_model(ubj_path) + print(f"Saved {json_path}") + print(f"Saved {ubj_path}") + + test_rows = X_test[:5] + dtest = xgb.DMatrix(test_rows) + raw_preds = model.get_booster().predict(dtest, output_margin=True) + prob_preds = model.predict_proba(test_rows)[:, 1] + + 
expected = [] + for i in range(len(test_rows)): + expected.append( + { + "features": test_rows[i].tolist(), + "raw_score": float(raw_preds[i]), + "probability": float(prob_preds[i]), + } + ) + + expected_path = os.path.join(SCRIPT_DIR, "test_binary_logistic_expected.json") + with open(expected_path, "w") as f: + json.dump(expected, f, indent=2) + print(f"Saved {expected_path}") + + print(f"\nXGBoost version: {xgb.__version__}") + for e in expected: + print(f" features={[round(v,4) for v in e['features']]} raw={e['raw_score']:.6f} prob={e['probability']:.6f}") + + +if __name__ == "__main__": + main() diff --git a/internal/xgjson/testdata/test_binary_logistic.ubj b/internal/xgjson/testdata/test_binary_logistic.ubj new file mode 100644 index 0000000..2123e14 Binary files /dev/null and b/internal/xgjson/testdata/test_binary_logistic.ubj differ diff --git a/internal/xgjson/testdata/test_binary_logistic_expected.json b/internal/xgjson/testdata/test_binary_logistic_expected.json new file mode 100644 index 0000000..d116578 --- /dev/null +++ b/internal/xgjson/testdata/test_binary_logistic_expected.json @@ -0,0 +1,47 @@ +[ + { + "features": [ + -1.1096492747839761, + 0.40340213424485294, + 0.07954668421267086 + ], + "raw_score": -0.7728347778320312, + "probability": 0.3158662021160126 + }, + { + "features": [ + -1.403825230749223, + 1.3655851426765793, + 0.5469370969348255 + ], + "raw_score": 1.4077504873275757, + "probability": 0.8034109473228455 + }, + { + "features": [ + -0.8717338325160207, + 1.0131613603890717, + 1.5004776410785952 + ], + "raw_score": 2.404935121536255, + "probability": 0.9172028303146362 + }, + { + "features": [ + -0.7083348159161926, + 1.087972292415291, + 1.2560764067197803 + ], + "raw_score": 3.208322286605835, + "probability": 0.9611462950706482 + }, + { + "features": [ + -1.2069009394127748, + 1.2786601324536067, + -0.011722010905850233 + ], + "raw_score": -0.7728347778320312, + "probability": 0.3158662021160126 + } +] \ No newline at end of file 
diff --git a/internal/xgjson/testdata/test_multiclass.json b/internal/xgjson/testdata/test_multiclass.json new file mode 100644 index 0000000..356a9f3 --- /dev/null +++ b/internal/xgjson/testdata/test_multiclass.json @@ -0,0 +1 @@ +{"learner":{"attributes":{},"feature_names":[],"feature_types":[],"gradient_booster":{"model":{"cats":{"enc":[],"feature_segments":[],"sorted_idx":[]},"gbtree_model_param":{"num_parallel_tree":"1","num_trees":"30"},"iteration_indptr":[0,3,6,9,12,15,18,21,24,27,30],"tree_info":[0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2,0,1,2],"trees":[{"base_weights":[8.976922E-8,1.1891924E0,-3.6806875E-1,-1.565994E-1,1.2743015E0,-2.0320473E-2,-5.801027E-1,4.2051739E-1,9.365237E-2,-1.289252E-1,6.1291154E-2,-1.9667643E-1,-6.719352E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":0,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[9.881445E1,8.436264E0,1.2699818E1,0E0,5.8297577E0,6.1886563E0,2.7311745E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-5.058216E-1,-2.000158E0,6.0872287E-1,-1.565994E-1,1.4570175E0,-1.0269129E0,5.924044E-1,4.2051739E-1,9.365237E-2,-1.289252E-1,6.1291154E-2,-1.9667643E-1,-6.719352E-2],"split_indices":[3,0,3,0,2,1,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[2.23756E2,5.2358906E1,1.713971E2,2.23756E0,5.0121346E1,6.5336754E1,1.0606034E2,4.3856174E1,6.265168E0,2.2823112E1,4.251364E1,8.681733E1,1.9243015E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[-1.6837087E-7,4.479209E-1,-5.045543E-1,5.4555234E-2,1.2178032E0,8.804552E-1,-6.205394E-1,-1.6066408E-1,6.819169E-2,3.925975E-1,-1.4185265E-1,-5.911352E-2,3.8677734E-1,-2.1556853E-1,-8.2272686E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"defa
ult_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":1,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[5.010972E1,3.5503872E1,1.72964E1,8.15283E0,6.7124557E0,3.98522E0,3.088787E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-3.717547E-1,-4.0281177E-1,-1.014656E0,-4.4667578E-1,1.850578E0,-3.530614E-1,1.1128337E0,-1.6066408E-1,6.819169E-2,3.925975E-1,-1.4185265E-1,-5.911352E-2,3.8677734E-1,-2.1556853E-1,-8.2272686E-2],"split_indices":[2,0,0,3,3,3,3,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[2.1972401E2,1.1645373E2,1.0327029E2,7.77823E1,3.8671425E1,7.4706163E0,9.579967E1,1.7138474E1,6.064383E1,3.6913635E1,1.7577921E0,2.19724E0,5.2733765E0,7.382727E1,2.19724E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[3.2448177E-8,-6.637632E-1,1.5248057E-1,-7.0995367E-1,1.2725432E-1,-1.4294602E-1,6.348038E-1,-2.201662E-1,-1.0262389E-3,-1.0314276E-1,2.357759E-1,-1.4246517E-1,2.6561716E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":2,"left_children":[1,3,5,7,-1,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[2.2782997E1,2.4143295E0,2.617415E1,6.85936E-1,0E0,2.160259E1,1.9806606E1,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,5,5,6,6],"right_children":[2,4,6,8,-1,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-9.3936366E-1,3.0433664E0,-1.8242843E-1,1.5687288E0,1.2725432E-1,1.5553991E0,-3.7348688E-1,-2.201662E-1,-1.0262389E-3,-1.0314276E-1,2.357759E-1,-1.4246517E-1,2.6561716E-1],"split_indices":[3,0,2,1,0,3,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[2.2310402E2,4.105114E1,1.8205287E2,3.9712513E1,1.3386241E0,1.1333684E2,6.8716034E1,3.837389E1,1.3386241E0,9.370369E1,1.9633154E1,1.2493825E1,5.622221E1],"tree_param":{"num_deleted":"0",
"num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[-2.0901388E-3,7.066297E-1,-3.7673816E-1,-2.7283943E-1,7.987066E-1,-1.166203E-1,-5.032613E-1,1.8392135E-1,-1.826072E-1,2.6300237E-1,2.2548191E-2,-2.0539433E-1,1.1224851E-2,-1.7371142E-1,-5.1528372E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":3,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[5.8382786E1,7.0756645E0,4.6751633E0,2.5390375E0,3.869522E0,4.2579136E0,2.3238049E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-1.947818E-1,-2.2962296E0,6.0872287E-1,-1.5332314E0,1.5481571E0,-1.6530755E0,5.924044E-1,1.8392135E-1,-1.826072E-1,2.6300237E-1,2.2548191E-2,-2.0539433E-1,1.1224851E-2,-1.7371142E-1,-5.1528372E-2],"split_indices":[3,1,3,3,2,0,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[2.1787793E2,7.503966E1,1.4283827E2,6.2793884E0,6.876027E1,4.73708E1,9.546747E1,1.4992453E0,4.7801433E0,6.1880505E1,6.879761E0,9.370666E0,3.8000137E1,7.70973E1,1.837017E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-1.3279037E-3,3.171575E-1,-4.6415395E-1,2.3645068E-2,8.152216E-1,5.511289E-1,-5.8502877E-1,1.2800452E-1,-3.8428452E-2,2.995066E-1,1.11927696E-1,-1.4434238E-1,2.6591742E-1,-1.937455E-1,-1.0518843E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":4,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[3.192071E1,1.8734081E1,1.1155401E1,5.053966E0,3.4765015E0,3.8548107E0,9.04459E-1,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-2.877946E-1,-4.02811
77E-1,-9.587663E-1,-1.1083441E0,-7.3509055E-1,-3.530614E-1,1.1128337E0,1.2800452E-1,-3.8428452E-2,2.995066E-1,1.11927696E-1,-1.4434238E-1,2.6591742E-1,-1.937455E-1,-1.0518843E-1],"split_indices":[2,0,0,1,2,3,3,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[2.1455482E2,1.2728244E2,8.7272385E1,8.074586E1,4.6536583E1,8.90614E0,7.836624E1,2.1589642E1,5.915621E1,3.2016865E1,1.4519718E1,2.0767446E0,6.8293953E0,6.083427E1,1.7531973E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[3.388171E-3,-5.8328724E-1,1.2091643E-1,-6.3269967E-1,1.24666505E-1,-1.1077672E-1,4.6495634E-1,-1.9731566E-1,2.7590208E-3,-8.5615814E-2,1.7995414E-1,-1.2866977E-1,1.9218363E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":5,"left_children":[1,3,5,7,-1,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[1.5205626E1,2.0794935E0,1.469448E1,5.7564735E-1,0E0,1.3892177E1,1.1912209E1,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,5,5,6,6],"right_children":[2,4,6,8,-1,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-9.3936366E-1,3.0433664E0,-1.8242843E-1,1.5687288E0,1.24666505E-1,1.5553991E0,-3.7348688E-1,-1.9731566E-1,2.7590208E-3,-8.5615814E-2,1.7995414E-1,-1.2866977E-1,1.9218363E-1],"split_indices":[3,0,2,1,0,3,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[2.185053E2,3.5806103E1,1.8269919E2,3.446757E1,1.3385327E0,1.0958043E2,7.311876E1,3.3150803E1,1.316767E0,8.844642E1,2.1134008E1,1.1772185E1,6.1346577E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[-3.627935E-3,5.643761E-1,-3.2544547E-1,-3.6645105E-1,6.4271986E-1,-9.8873034E-2,-4.4622588E-1,-1.9266257E-1,1.1406814E-1,2.2241534E-1,-5.9353437E-2,-1.7357367E-1,1.7972209E-2,-1.5769641E-1,-3.8165253E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"def
ault_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":6,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[3.8057682E1,5.726818E0,3.5925999E0,1.578974E0,5.937004E0,3.6521955E0,2.1223698E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-1.947818E-1,-1.7600034E0,6.0872287E-1,-5.051962E-1,1.5481571E0,-1.543846E0,5.924044E-1,-1.9266257E-1,1.1406814E-1,2.2241534E-1,-5.9353437E-2,-1.7357367E-1,1.7972209E-2,-1.5769641E-1,-3.8165253E-2],"split_indices":[3,0,3,2,2,0,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[2.061954E2,7.430112E1,1.3189429E2,5.482639E0,6.881848E1,4.6494225E1,8.540007E1,4.106113E0,1.3765259E0,6.16948E1,7.123678E0,1.0904812E1,3.5589413E1,6.785685E1,1.7543219E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-2.3265986E-3,2.4472557E-1,-4.2390215E-1,2.2704393E-2,7.0940715E-1,5.565847E-1,-5.341267E-1,3.319532E-2,-1.5480052E-1,-8.584619E-2,2.3229305E-1,2.346827E-1,-1.1742429E-1,-1.847629E-1,-8.020587E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":7,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[2.145066E1,1.3381815E1,8.521408E0,4.269415E0,2.932415E0,2.056324E0,1.3383408E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-2.2600034E-1,-1.3960569E-1,-1.014656E0,1.850578E0,3.7759027E-1,2.864812E-1,1.1128337E0,3.319532E-2,-1.5480052E-1,-8.584619E-2,2.3229305E-1,2.346827E-1,-1.1742429E-1,-1.847629E-1,-8.020587E-2],"split_indices":[2,0,0,3,3,1,3,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[2.0396097E2,1.2886365E2,7.509732E1,8.790987E1,4.095378E1,7.179908E0,6.791741E1,7.625536E1,1
.1654502E1,2.2978382E0,3.865594E1,5.9404693E0,1.2394388E0,5.0985043E1,1.693237E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[5.8407024E-3,-7.454842E-2,6.61053E-1,-4.675233E-1,1.1418001E-1,-1.7490596E-1,9.3166107E-1,-1.788098E-1,1.9935963E-1,5.713473E-2,-2.0596418E-1,9.640786E-2,2.8833634E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":8,"left_children":[1,3,5,7,9,-1,11,-1,-1,-1,-1,-1,-1],"loss_changes":[1.1128508E1,1.4015454E1,8.316058E0,9.310222E0,7.946671E0,0E0,3.9007187E-2,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,6,6],"right_children":[2,4,6,8,10,-1,12,-1,-1,-1,-1,-1,-1],"split_conditions":[1.5481571E0,5.406891E-2,-7.3064186E-2,1.5989536E0,1.9810429E0,-1.7490596E-1,6.802459E-1,-1.788098E-1,1.9935963E-1,5.713473E-2,-2.0596418E-1,9.640786E-2,2.8833634E-1],"split_indices":[2,3,0,1,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[2.0934401E2,1.8725552E2,2.208848E1,6.027411E1,1.2698142E2,3.740002E0,1.8348478E1,5.4554634E1,5.7194715E0,1.1663419E2,1.0347224E1,1.3932775E0,1.6955202E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[-6.520784E-3,5.034979E-1,-2.6448688E-1,-3.0338088E-1,5.7349336E-1,-6.451221E-2,-3.969988E-1,-1.7097509E-1,9.132237E-2,1.9818838E-1,-1.3702532E-2,-1.8452558E-1,1.6406905E-2,-1.3624011E-1,1.7421275E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":9,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[2.5519024E1,3.8538609E0,3.3872843E0,1.121987E0,3.283474E0,3.4807763E0,2.03539E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-2.662504E-1,-1.7600034E0,6.0872
287E-1,-5.051962E-1,1.4570175E0,-1.6530755E0,1.8039737E0,-1.7097509E-1,9.132237E-2,1.9818838E-1,-1.3702532E-2,-1.8452558E-1,1.6406905E-2,-1.3624011E-1,1.7421275E-2],"split_indices":[3,0,3,2,2,0,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.9194925E2,6.4156044E1,1.27793205E2,4.855022E0,5.930102E1,5.152433E1,7.6268875E1,3.4182942E0,1.4367279E0,5.1926525E1,7.374498E0,8.429569E0,4.309476E1,6.7763145E1,8.505735E0],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-1.0645691E-3,2.2012648E-1,-3.5846508E-1,1.4714547E-2,5.941625E-1,4.735961E-1,-4.6156082E-1,1.12558745E-1,-3.77464E-2,1.9603299E-1,-1.4754216E-1,1.9787346E-1,-1.1305776E-1,-1.7235164E-1,-5.2932955E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":10,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.5373497E1,9.220861E0,6.595866E0,4.027528E0,3.061223E0,1.6229713E0,2.0450363E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-3.717547E-1,-4.0281177E-1,-1.014656E0,-1.0269129E0,1.850578E0,2.864812E-1,1.1128337E0,1.12558745E-1,-3.77464E-2,1.9603299E-1,-1.4754216E-1,1.9787346E-1,-1.1305776E-1,-1.7235164E-1,-5.2932955E-2],"split_indices":[2,0,0,1,3,1,3,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.9247069E2,1.191279E2,7.334279E1,7.756844E1,4.155945E1,7.6894355E0,6.565335E1,2.1289669E1,5.6278778E1,3.9785908E1,1.7735406E0,6.493782E0,1.1956534E0,4.6298138E1,1.9355217E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[7.2899703E-3,-4.8482883E-1,8.1380315E-2,-5.448387E-1,1.0623165E-1,1.1662075E-2,6.649444E-1,-1.7231289E-1,-4.9774544E-3,-2.6839187E-2,1.4042485E-1,-1.4367083E-1,2.4241816E-1],"categories":[],"categ
ories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":11,"left_children":[1,3,5,7,-1,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[7.3624496E0,1.5232654E0,7.129156E0,3.814006E-1,0E0,7.321403E0,3.4486837E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,5,5,6,6],"right_children":[2,4,6,8,-1,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-9.3936366E-1,1.5989536E0,1.553888E0,2.996951E0,1.0623165E-1,1.7482485E0,-2.538463E-1,-1.7231289E-1,-4.9774544E-3,-2.6839187E-2,1.4042485E-1,-1.4367083E-1,2.4241816E-1],"split_indices":[3,1,2,0,0,3,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.9984319E2,2.5424725E1,1.7441847E2,2.405463E1,1.3700948E0,1.5671567E2,1.7702793E1,2.2698114E1,1.3565165E0,1.2894925E2,2.776642E1,1.7077156E0,1.5995077E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[-9.485977E-3,5.173073E-1,-2.0132104E-1,-1.4228354E-1,5.712821E-1,-3.0152557E-2,-3.5412034E-1,1.8488286E-1,-9.377627E-2,-1.0566302E-1,3.632692E-2,-1.3348909E-1,-1.315392E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":12,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[1.8166344E1,2.83951E0,3.4349093E0,0E0,2.0207405E0,3.122415E0,1.9397707E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-5.058216E-1,-2.000158E0,6.0872287E-1,-1.4228354E-1,1.5989536E0,-1.0269129E0,5.924044E-1,1.8488286E-1,-9.377627E-2,-1.0566302E-1,3.632692E-2,-1.3348909E-1,-1.315392E-2],"split_indices":[3,0,3,0,1,1,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.7773206E2,4.6991978E1,1.3074008E2,2.023871E0,4.4968105E1,6.222824E1,6.851183E1,4.307794E1,1.8901649E0,1.9460339E1,4.2767902E1,5.265979E1,1.5852043E1],"tree_param":{"num_deleted":"0","num_feature":"4
","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[5.2180095E-4,1.7266046E-1,-3.6015958E-1,1.3660084E-2,5.300831E-1,3.773546E-1,-4.7569916E-1,-1.1902709E-2,2.6027045E-1,1.8343304E-1,-3.5189666E-2,-1.1239097E-1,1.705419E-1,-5.4319464E-2,-1.6800156E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":13,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.1327617E1,7.0113926E0,5.2022257E0,3.9741178E0,2.0731316E0,1.4404948E0,1.175416E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-2.2600034E-1,-1.3960569E-1,-9.587663E-1,2.6605442E0,2.8335938E0,-3.530614E-1,-1.4435455E0,-1.1902709E-2,2.6027045E-1,1.8343304E-1,-3.5189666E-2,-1.1239097E-1,1.705419E-1,-5.4319464E-2,-1.6800156E-1],"split_indices":[2,0,0,1,0,3,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.8044528E2,1.2250154E2,5.7943745E1,8.550337E1,3.6998173E1,7.54112E0,5.0402626E1,8.137421E1,4.129162E0,3.2916832E1,4.0813394E0,1.3335973E0,6.207523E0,1.19137945E1,3.848883E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[8.39162E-3,-4.4714794E-1,6.7530744E-2,-5.1339555E-1,9.307784E-2,7.640994E-3,5.6538594E-1,-1.6428992E-1,1.9248568E-3,-2.700297E-2,1.01534545E-1,-1.4207219E-1,2.1525909E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":14,"left_children":[1,3,5,7,-1,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[5.1692576E0,1.265688E0,5.0557585E0,3.7108374E-1,0E0,4.9567223E0,3.2181735E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,5,5,6,6],"right_children":[2,4,6,8,-1,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-9.3936366E-1,1.5989536E0,1.5481571E0,2.996951E0,9.307784E-2,1.5553991E0,-2.538463E-1,-1.
6428992E-1,1.9248568E-3,-2.700297E-2,1.01534545E-1,-1.4207219E-1,2.1525909E-1],"split_indices":[3,1,2,0,0,3,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.8975803E2,2.105025E1,1.6870778E2,1.9653406E1,1.3968422E0,1.5149854E2,1.7209246E1,1.8388708E1,1.2646985E0,1.1753218E2,3.396636E1,1.9283444E0,1.5280901E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[-1.2626532E-2,3.045353E-1,-2.5842E-1,-3.556971E-1,4.0415043E-1,-2.9166606E-1,7.693084E-1,-1.5010996E-1,1.1268311E-1,1.4982644E-1,-7.5805195E-2,-1.18902445E-1,-1.5447201E-2,5.9276357E-2,2.8622952E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":15,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.2976739E1,4.9250126E0,3.4277816E0,1.2525285E0,4.108309E0,2.2809978E0,1.5090215E-1,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[5.406891E-2,-1.5891352E0,2.9755492E0,-2.877946E-1,1.5481571E0,4.6791342E-1,1.7804539E-1,-1.5010996E-1,1.1268311E-1,1.4982644E-1,-7.5805195E-2,-1.18902445E-1,-1.5447201E-2,5.9276357E-2,2.8622952E-1],"split_indices":[3,0,0,2,2,2,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.644522E2,7.169758E1,9.275461E1,9.062452E0,6.2635124E1,9.054187E1,2.212743E0,7.8294077E0,1.2330447E0,5.491028E1,7.724848E0,6.2606094E1,2.7935774E1,1.0459479E0,1.1667951E0],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[2.145788E-3,1.712665E-1,-2.877501E-1,1.1636712E-2,4.8477724E-1,4.2308635E-1,-3.870616E-1,9.962344E-2,-3.300727E-2,1.6604947E-1,-1.4108709E-1,1.7400391E-1,-9.964812E-2,-1.6987023E-1,-4.1386925E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0
,0,0,0,0,0,0,0,0,0],"id":16,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[8.399696E0,5.4090395E0,4.6201315E0,2.8713953E0,2.6545057E0,1.1664158E0,2.417324E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-4.4955203E-1,-4.8287207E-1,-1.014656E0,-1.0269129E0,1.850578E0,2.864812E-1,8.159176E-1,9.962344E-2,-3.300727E-2,1.6604947E-1,-1.4108709E-1,1.7400391E-1,-9.964812E-2,-1.6987023E-1,-4.1386925E-2],"split_indices":[2,0,0,1,3,1,3,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.6932138E2,1.0719482E2,6.212656E1,7.171627E1,3.5478546E1,7.2160983E0,5.491046E1,1.9259457E1,5.245681E1,3.348932E1,1.98923E0,6.166775E0,1.0493231E0,3.1197876E1,2.3712585E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[9.509712E-3,-4.0005688E-2,4.2434493E-1,-3.4024462E-1,7.0136264E-2,-1.5706638E-1,6.43454E-1,-1.5161529E-1,1.8823531E-1,3.856378E-2,-1.911911E-1,6.0481496E-2,2.002758E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":17,"left_children":[1,3,5,7,9,-1,11,-1,-1,-1,-1,-1,-1],"loss_changes":[3.7399385E0,5.4086227E0,4.368891E0,7.328385E0,5.0388503E0,0E0,3.8328648E-2,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,6,6],"right_children":[2,4,6,8,10,-1,12,-1,-1,-1,-1,-1,-1],"split_conditions":[1.5481571E0,5.406891E-2,-7.3064186E-2,1.5989536E0,1.9810429E0,-1.5706638E-1,6.802459E-1,-1.5161529E-1,1.8823531E-1,3.856378E-2,-1.911911E-1,6.0481496E-2,2.002758E-1],"split_indices":[2,3,0,1,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.802478E2,1.6183453E2,1.8413269E1,4.2873997E1,1.1896053E2,3.1962652E0,1.5217004E1,3.703212E1,5.841878E0,1.1064356E2,8.316973E0,1.2724842E0,1.3944519E1],"tree_param":{"num_deleted":"0","num_feature":"4","num
_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[-1.3786456E-2,4.232589E-1,-1.5453258E-1,-1.23297E-1,4.7504672E-1,-5.4358995E-1,-9.4504744E-2,1.7179418E-1,1.6227977E-2,-1.7619376E-1,3.070284E-3,1.2520383E-1,-5.3179897E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":18,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[9.444844E0,1.8115249E0,2.6593513E0,0E0,1.4352713E0,3.7659693E-1,4.384926E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-5.058216E-1,-2.000158E0,-1.7600034E0,-1.23297E-1,9.893304E-1,1.7568178E-1,-1.8138796E0,1.7179418E-1,1.6227977E-2,-1.7619376E-1,3.070284E-3,1.2520383E-1,-5.3179897E-2],"split_indices":[3,0,0,0,1,2,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.514803E2,3.6410633E1,1.1506967E2,1.7227559E0,3.4687878E1,1.4304238E1,1.00765434E2,2.7866507E1,6.8213706E0,1.3201552E1,1.1026858E0,1.3462419E1,8.730302E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[3.3527268E-3,-4.1702497E-1,9.393785E-2,1.7970294E-1,-5.137885E-1,2.1195391E-1,-1.9290258E-1,7.9478994E-2,-1.691505E-1,8.768038E-2,-1.2975918E-1,8.9961015E-2,-1.0973162E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":19,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[6.1304564E0,3.1138854E0,4.5234766E0,0E0,1.1859779E0,5.0072107E0,3.442437E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-3.4058717E-1,-2.850353E0,-3.717547E-1,1.7970294E-1,-2.399855E0,1.850578E0,-4.6335888E-1,7.9478994E-2,-1.691505E-1,8.768038E-2,-1.2975918E-1,8.9961015E-2,-1.0973162E-1],"split_indices":[3,1,2,0,0,3,0,0,0,0,0,0,0],"spl
it_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.5896027E2,2.7542112E1,1.3141815E2,1.9435924E0,2.559852E1,9.329461E1,3.812355E1,1.2865309E0,2.4311989E1,8.34427E1,9.851909E0,9.71002E0,2.841353E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[9.11443E-3,-3.864425E-1,4.776072E-2,-4.6021813E-1,8.523351E-2,-6.5666504E-2,2.1095058E-1,-1.5276419E-1,1.9150585E-2,-7.7280626E-2,1.0416328E-1,1.0260255E-1,-1.8611878E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":20,"left_children":[1,3,5,7,-1,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[2.6447253E0,8.926511E-1,2.9273286E0,3.9855504E-1,0E0,7.5084167E0,7.2701354E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,5,5,6,6],"right_children":[2,4,6,8,-1,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-9.3936366E-1,3.0433664E0,-1.8242843E-1,1.5687288E0,8.523351E-2,1.0844262E0,1.6783428E0,-1.5276419E-1,1.9150585E-2,-7.7280626E-2,1.0416328E-1,1.0260255E-1,-1.8611878E-1],"split_indices":[3,0,2,1,0,1,1,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.7079967E2,1.4401037E1,1.5639862E2,1.3256236E1,1.1448007E0,9.261971E1,6.377892E1,1.2148722E1,1.1075149E0,6.348386E1,2.9135849E1,5.5602116E1,8.176804E0],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[-1.2654852E-2,4.538828E-1,-1.0541267E-1,4.9121177E-1,-2.8421942E-2,-5.156002E-1,-4.9706228E-2,-5.619883E-2,1.627176E-1,-1.6806856E-1,5.9875115E-3,1.1171465E-1,-3.6554433E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":21,"left_children":[1,3,5,7,-1,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[6.18774E0,5.322919E-1,2.697228E0,8.78458E-1,0E0,3.4917474E-1,3.2737217E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,5,5,6,6],"right_children":[2,4,6,8,-1,10,12,-1,
-1,-1,-1,-1,-1],"split_conditions":[-9.3936366E-1,1.5989536E0,-1.7600034E0,-1.6238737E0,-2.8421942E-2,1.7568178E-1,-1.8138796E0,-5.619883E-2,1.627176E-1,-1.6806856E-1,5.9875115E-3,1.1171465E-1,-3.6554433E-2],"split_indices":[3,1,0,0,0,2,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.4088081E2,2.2719069E1,1.1816175E2,2.1369694E1,1.3493742E0,1.3141353E1,1.050204E2,1.3122215E0,2.0057472E1,1.2087991E1,1.053362E0,1.4722142E1,9.0298256E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[3.821175E-3,1.24513425E-1,-2.3297627E-1,-6.062304E-3,4.2964745E-1,3.8214183E-1,-3.2982954E-1,-1.7670315E-2,2.2356662E-1,1.4584626E-1,-1.2079639E-1,-4.772599E-2,1.683034E-1,-1.6292854E-1,-2.9046834E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":22,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[4.3606954E0,4.0286374E0,3.195314E0,2.8921854E0,1.620605E0,8.125682E-1,2.1858974E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-4.4955203E-1,-4.0281177E-1,-1.014656E0,2.6605442E0,1.850578E0,-3.530614E-1,8.159176E-1,-1.7670315E-2,2.2356662E-1,1.4584626E-1,-1.2079639E-1,-4.772599E-2,1.683034E-1,-1.6292854E-1,-2.9046834E-2],"split_indices":[2,0,0,1,3,3,3,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.505656E2,1.0004718E2,5.0518425E1,7.075093E1,2.9296247E1,6.4715886E0,4.4046837E1,6.6962944E1,3.787986E0,2.7823166E1,1.4730812E0,1.6349894E0,4.8365993E0,2.2302801E1,2.1744038E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[7.387525E-3,-3.6329985E-2,3.1401834E-1,3.922867E-2,-4.8384723E-1,-4.8578045E-1,1.9148001E-1,-3.420363E-2,7.310421E-2,-1.8171448E-1,1.0716105E-1,-2.429568E-2,-1.6776676E-1],"cat
egories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":23,"left_children":[1,3,5,7,9,11,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[2.218499E0,4.921865E0,5.664933E0,3.9433513E0,2.3708324E0,1.4866972E-1,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5],"right_children":[2,4,6,8,10,12,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[1.4425881E0,1.3956678E0,4.5153925E-1,1.0111414E0,1.2216467E0,-5.167576E-1,1.9148001E-1,-3.420363E-2,7.310421E-2,-1.8171448E-1,1.0716105E-1,-2.429568E-2,-1.6776676E-1],"split_indices":[2,0,0,3,2,1,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.6364516E2,1.4399655E2,1.9648613E1,1.2397671E2,2.0019842E1,5.5242333E0,1.4124379E1,7.1116135E1,5.286057E1,1.7730515E1,2.2893276E0,1.1712577E0,4.3529754E0],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[-1.4467765E-2,2.080743E-1,-1.7292127E-1,-4.7704002E-1,2.9902416E-1,-2.0685807E-1,6.884434E-1,-7.605224E-3,-1.697991E-1,1.3227023E-1,-1.7589947E-2,-8.156897E-2,6.693411E-2,5.9669532E-2,2.4860035E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":24,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[4.736835E0,3.6256235E0,2.4345884E0,2.643479E-1,2.5610108E0,2.191857E0,5.507636E-2,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[5.406891E-2,-2.2962296E0,2.9755492E0,-1.2172549E0,9.893304E-1,1.8039737E0,1.7804539E-1,-7.605224E-3,-1.697991E-1,1.3227023E-1,-1.7589947E-2,-8.156897E-2,6.693411E-2,5.9669532E-2,2.4860035E-1],"split_indices":[3,1,0,3,1,1,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.3231004E2,5.489662E1,7.741342E1,5.9360414E0,4.896058E1,7.5210075E1,2.2033439E0,1.1881337
E0,4.7479076E0,3.4888256E1,1.4072324E1,6.56477E1,9.562378E0,1.0290837E0,1.1742603E0],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[6.7778416E-3,-3.662384E-1,7.065992E-2,1.2101906E-1,-5.0399894E-1,3.104683E-1,-2.9183773E-2,-4.156303E-3,-1.6096625E-1,1.2665306E-1,-9.17492E-2,2.9581118E-2,-1.1535417E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":25,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[3.4348352E0,2.4382627E0,2.953601E0,0E0,2.8495455E-1,2.6238682E0,4.0220222E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-3.4058717E-1,-2.2962296E0,-8.278937E-1,1.2101906E-1,-2.600786E0,1.8622017E0,-5.490728E-1,-4.156303E-3,-1.6096625E-1,1.2665306E-1,-9.17492E-2,2.9581118E-2,-1.1535417E-1],"split_indices":[3,1,1,0,2,3,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.4205885E2,2.0079405E1,1.2197945E2,2.755703E0,1.7323702E1,3.5236805E1,8.674264E1,1.1677582E0,1.6155943E1,3.009657E1,5.1402345E0,6.4330086E1,2.2412552E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[6.1122016E-3,-3.0932928E-2,2.8709438E-1,3.4882322E-2,-4.5499435E-1,-4.5712167E-1,1.8004636E-1,-6.688488E-2,4.297451E-2,-1.7304227E-1,9.31752E-2,-1.6844256E-2,-1.6021039E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":26,"left_children":[1,3,5,7,9,11,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.6421589E0,3.9102547E0,4.555378E0,3.4327443E0,1.9340532E0,1.5531588E-1,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5],"right_children":[2,4,6,8,10,12,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[1.4425881E0,1.3956678E0,4.5153925E-1,-8.182202E-1,1.2216467E0,-5.167576E-1,1.
8004636E-1,-6.688488E-2,4.297451E-2,-1.7304227E-1,9.31752E-2,-1.6844256E-2,-1.6021039E-1],"split_indices":[2,0,0,1,2,1,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.5590987E2,1.385351E2,1.7374771E1,1.20717255E2,1.7817837E1,5.008424E0,1.2366348E1,3.5409992E1,8.530727E1,1.5584688E1,2.233149E0,1.0845095E0,3.9239142E0],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[-1.2841339E-2,3.8976517E-1,-7.814242E-2,-2.4121646E-2,4.3709022E-1,-4.9403644E-1,-2.5205292E-2,1.4776048E-1,-4.215324E-2,-1.6158266E-1,-4.3613804E-3,1.0649363E-1,-2.8236344E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0],"id":27,"left_children":[1,3,5,-1,7,9,11,-1,-1,-1,-1,-1,-1],"loss_changes":[3.3149943E0,4.2681456E-1,2.3725522E0,0E0,5.8634734E-1,2.524433E-1,2.5788443E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,4,4,5,5,6,6],"right_children":[2,4,6,-1,8,10,12,-1,-1,-1,-1,-1,-1],"split_conditions":[-9.3936366E-1,-1.842745E0,-1.7600034E0,-2.4121646E-2,1.5365523E0,-5.051962E-1,-1.8138796E0,1.4776048E-1,-4.215324E-2,-1.6158266E-1,-4.3613804E-3,1.0649363E-1,-2.8236344E-2],"split_indices":[3,0,0,0,1,2,2,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2393182E2,1.660246E1,1.0732936E2,1.4543734E0,1.51480875E1,1.117803E1,9.615133E1,1.3955733E1,1.1923542E0,1.0114544E1,1.063486E0,1.4117768E1,8.203356E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"13","size_leaf_vector":"1"}},{"base_weights":[5.4487037E-3,1.499641E-1,-1.4004897E-1,-7.177455E-2,6.129975E-1,5.26862E-1,-2.9116732E-1,-4.130257E-2,1.8700607E-1,2.0334715E-1,-1.04372114E-1,-6.4618096E-2,2.0535553E-1,-1.618584E-1,-5.0461195E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":28,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_chang
es":[2.9090364E0,7.168543E0,7.10337E0,2.2762642E0,1.6420355E0,1.6798639E0,1.6458578E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-9.239956E-1,-7.7296084E-1,-1.0505564E0,2.6605442E0,1.850578E0,-3.530614E-1,5.4416007E-1,-4.130257E-2,1.8700607E-1,2.0334715E-1,-1.04372114E-1,-6.4618096E-2,2.0535553E-1,-1.618584E-1,-5.0461195E-2],"split_indices":[2,0,0,1,3,3,3,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.363511E2,6.839106E1,6.796004E1,4.681646E1,2.15746E1,1.2095245E1,5.5864796E1,4.3494926E1,3.3215344E0,2.0487062E1,1.0875387E0,2.0548177E0,1.0040428E1,1.73777E1,3.8487095E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[5.6611793E-3,3.773538E-2,-2.841661E-1,-3.950626E-2,3.0515108E-1,3.110922E-1,-5.7242596E-1,-4.867336E-2,1.1077137E-1,2.9061523E-1,-5.970359E-2,2.6886097E-1,-1.14180885E-1,-2.0935065E-1,7.355219E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":29,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.4161686E0,2.842019E0,2.7830777E0,5.437348E0,1.0503294E1,2.6079679E0,1.2559624E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[1.976097E0,1.2041476E0,1.2879819E-1,7.043773E-1,-4.8287207E-1,-1.1052026E0,5.9545773E-1,-4.867336E-2,1.1077137E-1,2.9061523E-1,-5.970359E-2,2.6886097E-1,-1.14180885E-1,-2.0935065E-1,7.355219E-2],"split_indices":[1,1,3,2,0,2,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.5018854E2,1.3600693E2,1.41816225E1,1.06187485E2,2.9819439E1,4.601105E0,9.580518E0,8.21288E1,2.4058685E1,1.247661E1,1.7342829E1,2.3324249E0,2.2686806E0,8.43342E0,1.1470972E0],"tree_param":{"num_deleted
":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}}]},"name":"gbtree"},"learner_model_param":{"base_score":"[1.4027715E-2,-2.2120595E-2,8.092999E-3]","boost_from_average":"1","num_class":"3","num_feature":"4","num_target":"1"},"objective":{"name":"multi:softprob","softmax_multiclass_param":{"num_class":"3"}}},"version":[3,2,0]} \ No newline at end of file diff --git a/internal/xgjson/testdata/test_multiclass.py b/internal/xgjson/testdata/test_multiclass.py new file mode 100644 index 0000000..169184d --- /dev/null +++ b/internal/xgjson/testdata/test_multiclass.py @@ -0,0 +1,97 @@ +# /// script +# dependencies = ["xgboost>=2.0", "numpy", "scikit-learn"] +# /// +""" +Generates multi:softprob XGBoost test model (3-class). +Run with: uv run test_multiclass.py +Outputs: test_multiclass.json, test_multiclass.ubj, test_multiclass_expected.json + +NOTE on base_score: XGBoost 3.x stores a per-class base_score vector for multi-class models. +The leaves Go library (xgEnsemble) only supports a single scalar BaseScore, so it uses 0.0 +for multi-class. Expected values here are the raw tree sums (output_margin minus per-class +base_score) and the softmax of those sums, which is what the Go library actually computes. 
+""" + +import json +import os +import numpy as np +from sklearn.datasets import make_classification +import xgboost as xgb + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def softmax(x): + e = np.exp(x - x.max(axis=1, keepdims=True)) + return e / e.sum(axis=1, keepdims=True) + + +def main(): + X, y = make_classification( + n_samples=600, + n_features=4, + n_classes=3, + n_informative=4, + n_redundant=0, + random_state=42, + ) + + X_train, X_test = X[:500], X[500:] + y_train = y[:500] + + model = xgb.XGBClassifier( + n_estimators=10, + max_depth=3, + objective="multi:softprob", + num_class=3, + random_state=42, + ) + model.fit(X_train, y_train) + + json_path = os.path.join(SCRIPT_DIR, "test_multiclass.json") + ubj_path = os.path.join(SCRIPT_DIR, "test_multiclass.ubj") + model.get_booster().save_model(json_path) + model.get_booster().save_model(ubj_path) + print(f"Saved {json_path}") + print(f"Saved {ubj_path}") + + # Parse per-class base_scores from the saved model. + # XGBoost 3.x stores these as "[v0,v1,...,vN]" in learner_model_param.base_score. + with open(json_path) as f: + saved = json.load(f) + bs_str = saved["learner"]["learner_model_param"]["base_score"].strip("[]") + base_scores = np.array([float(x) for x in bs_str.split(",")]) + print(f"Per-class base_scores: {base_scores.tolist()}") + + test_rows = X_test[:5] + dtest = xgb.DMatrix(test_rows) + # output_margin=True gives per-class raw scores including per-class base_score + raw_preds = model.get_booster().predict(dtest, output_margin=True) # shape: [n, 3] + + # Go uses BaseScore=0 for multi-class (xgEnsemble doesn't support per-class base_score). + # Subtract per-class base_scores to get the pure tree sums that Go computes. 
+ go_raw_preds = raw_preds - base_scores # broadcast over rows + go_probs = softmax(go_raw_preds) + + expected = [] + for i in range(len(test_rows)): + expected.append( + { + "features": test_rows[i].tolist(), + "raw_scores": go_raw_preds[i].tolist(), + "probabilities": go_probs[i].tolist(), + } + ) + + expected_path = os.path.join(SCRIPT_DIR, "test_multiclass_expected.json") + with open(expected_path, "w") as f: + json.dump(expected, f, indent=2) + print(f"Saved {expected_path}") + + print(f"\nXGBoost version: {xgb.__version__}") + for e in expected: + print(f" raw={[round(v,4) for v in e['raw_scores']]} prob={[round(v,4) for v in e['probabilities']]}") + + +if __name__ == "__main__": + main() diff --git a/internal/xgjson/testdata/test_multiclass.ubj b/internal/xgjson/testdata/test_multiclass.ubj new file mode 100644 index 0000000..4a2c619 Binary files /dev/null and b/internal/xgjson/testdata/test_multiclass.ubj differ diff --git a/internal/xgjson/testdata/test_multiclass_expected.json b/internal/xgjson/testdata/test_multiclass_expected.json new file mode 100644 index 0000000..69766d5 --- /dev/null +++ b/internal/xgjson/testdata/test_multiclass_expected.json @@ -0,0 +1,92 @@ +[ + { + "features": [ + 1.2898756580524573, + 0.9248578543855488, + -1.2090147492408292, + 1.1638837750179802 + ], + "raw_scores": [ + -1.1162556412542726, + 1.9363673925616456, + -0.1567775006479492 + ], + "probabilities": [ + 0.04035327308558963, + 0.8543112105802694, + 0.10533551633414107 + ] + }, + { + "features": [ + 0.14181489267063685, + -2.4478019338481474, + 0.812756653818766, + 1.1106575556359448 + ], + "raw_scores": [ + -0.5212679508055115, + -0.9247214793942261, + 0.7192503814512023 + ], + "probabilities": [ + 0.1951061845696842, + 0.13033270078894887, + 0.674561114641367 + ] + }, + { + "features": [ + 0.08734803051372408, + 0.9105387201543911, + 0.7767547146693972, + -0.7705641418952007 + ], + "raw_scores": [ + 1.6781063077127074, + -1.7592041492245483, + 0.2956783180051575 + ], 
+ "probabilities": [ + 0.7793506344474285, + 0.0250569340114191, + 0.19559243154115238 + ] + }, + { + "features": [ + 1.2445468253247949, + 1.0397630318427356, + -1.1561191254665355, + 1.0761159326795893 + ], + "raw_scores": [ + -1.1162556412542726, + 1.9363673925616456, + -0.1567775006479492 + ], + "probabilities": [ + 0.04035327308558963, + 0.8543112105802694, + 0.10533551633414107 + ] + }, + { + "features": [ + -0.32039158134005596, + 2.703654404944862, + -1.9104368617091088, + 1.0037481848182237 + ], + "raw_scores": [ + -0.35270831015953064, + 1.8141070604541016, + -0.24331870629255675 + ], + "probabilities": [ + 0.0921996052286053, + 0.8049428003930158, + 0.10285759437837882 + ] + } +] \ No newline at end of file diff --git a/internal/xgjson/testdata/test_poisson.json b/internal/xgjson/testdata/test_poisson.json new file mode 100644 index 0000000..b26e740 --- /dev/null +++ b/internal/xgjson/testdata/test_poisson.json @@ -0,0 +1 @@ +{"learner":{"attributes":{},"feature_names":[],"feature_types":[],"gradient_booster":{"model":{"cats":{"enc":[],"feature_segments":[],"sorted_idx":[]},"gbtree_model_param":{"num_parallel_tree":"1","num_trees":"10"},"iteration_indptr":[0,1,2,3,4,5,6,7,8,9,10],"tree_info":[0,0,0,0,0,0,0,0,0,0],"trees":[{"base_weights":[1.608796E-8,-1.6029176E-1,7E-1,-3.0662945E-1,8.970203E-2,2.1499042E-1,7E-1,-1.1313498E-1,-5.6605335E-2,-3.312177E-2,1.3046043E-1,4.193799E-3,1.391802E-1,2.1000001E-1,3.4098327E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":0,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.5571881E2,3.7720955E1,2.4467834E1,5.3337936E0,2.6378124E1,5.30269E0,9.165039E-1,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[8.384892E-1,3.4935063E-1,-2.5862848E-2,6.1704636E-2,2.172312E-1,1.3206617E0,1.8212988E0,-1.
1313498E-1,-5.6605335E-2,-3.312177E-2,1.3046043E-1,4.193799E-3,1.391802E-1,2.1000001E-1,3.4098327E-2],"split_indices":[2,0,0,2,2,2,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2364442E3,1.0293398E3,2.071044E2,6.491332E2,3.8020657E2,1.05097755E2,1.02006645E2,4.049355E2,2.4419772E2,2.4110661E2,1.3909998E2,5.87311E1,4.6366657E1,9.8915535E1,3.0911105E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[6.0249087E-3,-1.4485276E-1,6.292016E-1,-2.8846076E-1,7.249133E-2,1.7385581E-1,7E-1,-1.2007688E-1,-6.294832E-2,-2.9236889E-2,9.6497245E-2,-3.8690425E-2,7.703155E-2,2.1000001E-1,2.9422885E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":1,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.1504521E2,3.0768625E1,3.072609E1,5.132305E0,1.6661806E1,2.972922E0,1.0310974E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[8.384892E-1,3.4935063E-1,1.1451292E-2,-4.0526927E-1,2.172312E-1,1.0027578E0,1.8212988E0,-1.2007688E-1,-6.294832E-2,-2.9236889E-2,9.6497245E-2,-3.8690425E-2,7.703155E-2,2.1000001E-1,2.9422885E-2],"split_indices":[2,0,0,2,2,2,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2216106E3,9.841138E2,2.3749681E2,5.9237836E2,3.9173547E2,1.16082E2,1.2141481E2,2.4292001E2,3.494583E2,2.3325153E2,1.5848395E2,2.483281E1,9.124919E1,1.1821648E2,3.1983294E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[8.249043E-3,-1.3017151E-1,4.8842916E-1,-2.6980826E-1,5.9279393E-2,1.3929035E-1,7E-1,-1.0460437E-1,-4.492878E-2,-2.764657E-2,7.164046E-2,7.21004E-2,-1.7957078E-2,1.4741372E-1,2.1000001E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes
":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":2,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[8.0968575E1,2.5032322E1,2.6119629E1,5.103985E0,1.0952632E1,2.4984605E0,4.2014847E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[8.384892E-1,3.4935063E-1,1.1451292E-2,6.1704636E-2,1.5704857E-1,2.0279728E-1,1.3591876E0,-1.0460437E-1,-4.492878E-2,-2.764657E-2,7.164046E-2,7.21004E-2,-1.7957078E-2,1.4741372E-1,2.1000001E-1],"split_indices":[2,0,0,2,2,1,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2162241E3,9.446428E2,2.7158136E2,5.435727E2,4.010701E2,1.22446434E2,1.4913492E2,3.268921E2,2.1668063E2,2.1781798E2,1.8325212E2,8.109383E1,4.1352604E1,1.0679396E2,4.234096E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[9.134995E-3,-1.2018245E-1,3.7619865E-1,-2.7463865E-1,2.4330707E-2,2.2345707E-1,7E-1,1.6753457E-2,-8.807992E-2,-8.588739E-2,3.03916E-2,-5.0609536E-2,7.969856E-2,2.1000001E-1,4.92894E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":3,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[5.772286E1,2.0079765E1,2.5428822E1,2.7408333E0,1.115958E1,4.354599E0,1.4625626E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[8.2798815E-1,1.599916E-1,1.3591876E0,-1.5338614E0,-9.5888263E-1,8.978154E-1,4.438127E-1,1.6753457E-2,-8.807992E-2,-8.588739E-2,3.03916E-2,-5.0609536E-2,7.969856E-2,2.1000001E-1,4.92894E-2],"split_indices":[2,0,0,1,2,2,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2140918E3,8.982934E2,3.1579843E2,4.3377383E2,4.6451956E2,2.5939658E2,5.6401833E1,2.3431297E1
,4.1034253E2,9.171107E1,3.7280847E2,2.491316E1,2.3448343E2,5.059793E1,5.8039017E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[9.09531E-3,-2.0343041E-1,1.8307982E-1,-2.3020588E-1,4.771158E-1,1.12988435E-1,7E-1,2.3115667E-2,-7.8853436E-2,5.2792482E-2,2.1000001E-1,-1.2131449E-2,6.406971E-2,2.1000001E-1,2.3619696E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":4,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[4.4970226E1,1.0085991E1,2.5175674E1,5.3096256E0,1.9333873E0,9.15984E0,7.2229195E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[2.344009E-1,1.9050602E0,1.5112284E0,-1.3204889E0,2.0931683E0,-5.119711E-2,4.438127E-1,2.3115667E-2,-7.8853436E-2,5.2792482E-2,2.1000001E-1,-1.2131449E-2,6.406971E-2,2.1000001E-1,2.3619696E-2],"split_indices":[2,0,0,1,0,0,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2141871E3,5.464805E2,6.677066E2,5.2643036E2,2.0050182E1,5.9203174E2,7.567487E1,5.042025E1,4.7601007E2,1.1633196E1,8.4169855E0,2.3466206E2,3.5736966E2,5.6691853E1,1.898302E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[8.181038E-3,-8.697243E-2,3.0043316E-1,-2.3265421E-1,2.8890641E-2,2.0781858E-1,7E-1,-1.0578607E-1,-4.8775777E-2,9.829865E-2,-7.159475E-3,1.6448982E-2,8.489892E-2,2.1000001E-1,3.727326E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":5,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[3.3880238E1,1.5521219E1,1.4152052E1,3.3662758E0,8.0914755E0,2.947919E0,1.9109726E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_c
hildren":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[9.5205367E-1,1.599916E-1,1.5112284E0,-4.0526927E-1,-1.1719716E0,-4.5647115E-1,3.8913205E-2,-1.0578607E-1,-4.8775777E-2,9.829865E-2,-7.159475E-3,1.6448982E-2,8.489892E-2,2.1000001E-1,3.727326E-2],"split_indices":[2,0,0,2,1,0,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2163882E3,9.181522E2,2.9823593E2,4.062903E2,5.1186188E2,2.5665982E2,4.157613E1,1.4831805E2,2.5797226E2,7.603583E1,4.3582605E2,8.513578E1,1.7152403E2,3.5333263E1,6.242867E0],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[7.0725908E-3,-1.7376871E-1,1.3126637E-1,-2.003981E-1,3.637069E-1,9.796992E-2,6.8071556E-1,-8.719458E-2,-2.7215052E-2,3.962268E-2,1.8250549E-1,-2.0084638E-2,4.5219928E-2,1.21616825E-1,2.1000001E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":6,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[2.7424751E1,7.1887817E0,1.3176699E1,4.669132E0,1.2697597E0,5.9593873E0,1.1783447E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[2.344009E-1,1.9050602E0,2.0769806E0,-4.0526927E-1,2.0931683E0,-3.8318732E-1,2.44513E0,-8.719458E-2,-2.7215052E-2,3.962268E-2,1.8250549E-1,-2.0084638E-2,4.5219928E-2,1.21616825E-1,2.1000001E-1],"split_indices":[2,0,0,2,0,0,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2190674E3,4.9617575E2,7.228916E2,4.7335062E2,2.2825144E1,6.826986E2,4.0193024E1,2.5877194E2,2.1457866E2,1.2515371E1,1.0309773E1,1.654118E2,5.172868E2,1.8398088E1,2.1794937E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[5.7086255E-3,-1.5682158E-1,1.0779569E-1,-2.0145491E-1,9.470294E-2,2.705563E-1,4.000411E-2,-8.7372
09E-2,-2.8699575E-2,4.7959937E-3,1.3134255E-1,2.499932E-2,1.3513407E-1,-2.7228395E-2,2.8268065E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":7,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[2.0315517E1,5.32216E0,8.286631E0,3.7857876E0,1.9340974E0,7.403927E0,3.7810888E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[2.344009E-1,1.3634287E0,-5.8487415E-1,-4.0526927E-1,2.0931683E0,5.4873884E-1,-7.971821E-2,-8.737209E-2,-2.8699575E-2,4.7959937E-3,1.3134255E-1,2.499932E-2,1.3513407E-1,-2.7228395E-2,2.8268065E-2],"split_indices":[2,0,1,2,0,0,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2223805E3,4.7137704E2,7.510034E2,4.0050586E2,7.0871185E1,2.1994563E2,5.310578E2,2.1569179E2,1.8481407E2,5.8497177E1,1.2374003E1,1.0849266E2,1.1145297E2,1.5546336E2,3.7559442E2],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[3.986343E-3,-6.855007E-2,1.8679157E-1,-1.9220346E-1,1.4377538E-2,3.6981606E-1,1.1181358E-1,3.1648107E-2,-6.5965526E-2,-6.833215E-2,1.5860908E-2,3.80832E-2,1.7194837E-1,1.7563814E-2,5.6632925E-2],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":8,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.6290684E1,9.022352E0,4.7478733E0,2.933526E0,4.928169E0,4.914747E0,1.0097888E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[9.5205367E-1,1.599916E-1,-5.7094365E-1,-1.2964723E0,-9.5888263E-1,-8.6334527E-1,1.5317789E0,3.1648107E-2,-6.5965526E-2,-6.833215E-2,1.5860908E-2,3.80832E-2,1.7194837E-1,1.7563814E-2,5.6632925E-2],"spl
it_indices":[2,0,1,1,2,1,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2265896E3,8.785925E2,3.4799716E2,3.5216357E2,5.2642896E2,9.998851E1,2.4800864E2,2.9721193E1,3.2244238E2,7.152883E1,4.549001E2,4.6305855E1,5.368266E1,1.4760732E2,1.0040132E2],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[2.6625937E-3,-1.5288876E-1,6.671782E-2,-1.8247098E-1,2.5947097E-1,1.1514038E-2,1.8340808E-1,8.038846E-2,-6.2043E-2,4.460892E-2,1.2815052E-1,-1.8337945E-2,3.2311738E-2,6.9908835E-2,-2.6620878E-3],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":9,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.2274185E1,4.430501E0,5.618476E0,3.7374134E0,3.9095032E-1,4.1527863E0,2.6750078E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[3.0631492E-2,1.9050602E0,8.705556E-1,-2.0729835E0,-6.9893044E-1,8.978154E-1,4.438127E-1,8.038846E-2,-6.2043E-2,4.460892E-2,1.2815052E-1,-1.8337945E-2,3.2311738E-2,6.9908835E-2,-2.6620878E-3],"split_indices":[2,0,0,1,2,2,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[1.2298446E3,3.5831866E2,8.7152594E2,3.3485413E2,2.3464561E1,5.9238116E2,2.7914478E2,1.6653053E1,3.1820105E2,1.5266671E1,8.19789E0,3.3771536E2,2.5466583E2,2.217156E2,5.7429165E1],"tree_param":{"num_deleted":"0","num_feature":"3","num_nodes":"15","size_leaf_vector":"1"}}]},"name":"gbtree"},"learner_model_param":{"base_score":"[1.535E0]","boost_from_average":"1","num_class":"0","num_feature":"3","num_target":"1"},"objective":{"name":"count:poisson","poisson_regression_param":{"max_delta_step":"0.699999988"}}},"version":[3,2,0]} \ No newline at end of file diff --git a/internal/xgjson/testdata/test_poisson.py b/internal/xgjson/testdata/test_poisson.py 
new file mode 100644 index 0000000..44d2a1a --- /dev/null +++ b/internal/xgjson/testdata/test_poisson.py @@ -0,0 +1,69 @@ +# /// script +# dependencies = ["xgboost>=2.0", "numpy", "scikit-learn"] +# /// +""" +Generates count:poisson XGBoost test model. +Run with: uv run test_poisson.py +Outputs: test_poisson.json, test_poisson.ubj, test_poisson_expected.json +""" + +import json +import os +import numpy as np +import xgboost as xgb + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def main(): + rng = np.random.default_rng(42) + X = rng.standard_normal((500, 3)) + y = rng.poisson(np.exp(X @ [0.5, -0.3, 0.8])) + + X_train, X_test = X[:400], X[400:] + y_train = y[:400] + + model = xgb.XGBRegressor( + n_estimators=10, + max_depth=3, + objective="count:poisson", + random_state=42, + ) + model.fit(X_train, y_train) + + json_path = os.path.join(SCRIPT_DIR, "test_poisson.json") + ubj_path = os.path.join(SCRIPT_DIR, "test_poisson.ubj") + model.get_booster().save_model(json_path) + model.get_booster().save_model(ubj_path) + print(f"Saved {json_path}") + print(f"Saved {ubj_path}") + + test_rows = X_test[:5] + dtest = xgb.DMatrix(test_rows) + # output_margin=True gives log-space raw scores + raw_preds = model.get_booster().predict(dtest, output_margin=True) + # predict() gives exp(margin) = count prediction + count_preds = model.predict(test_rows) + + expected = [] + for i in range(len(test_rows)): + expected.append( + { + "features": test_rows[i].tolist(), + "raw_score": float(raw_preds[i]), + "prediction": float(count_preds[i]), + } + ) + + expected_path = os.path.join(SCRIPT_DIR, "test_poisson_expected.json") + with open(expected_path, "w") as f: + json.dump(expected, f, indent=2) + print(f"Saved {expected_path}") + + print(f"\nXGBoost version: {xgb.__version__}") + for e in expected: + print(f" features={[round(v,4) for v in e['features']]} raw={e['raw_score']:.6f} pred={e['prediction']:.6f}") + + +if __name__ == "__main__": + main() diff --git 
a/internal/xgjson/testdata/test_poisson.ubj b/internal/xgjson/testdata/test_poisson.ubj new file mode 100644 index 0000000..c70dae2 Binary files /dev/null and b/internal/xgjson/testdata/test_poisson.ubj differ diff --git a/internal/xgjson/testdata/test_poisson_expected.json b/internal/xgjson/testdata/test_poisson_expected.json new file mode 100644 index 0000000..6a23356 --- /dev/null +++ b/internal/xgjson/testdata/test_poisson_expected.json @@ -0,0 +1,47 @@ +[ + { + "features": [ + -1.322541133291679, + -0.4861941306990469, + 0.42022670947458574 + ], + "raw_score": -0.01655573584139347, + "prediction": 0.9835805296897888 + }, + { + "features": [ + -0.10239670661109637, + -0.6505636204055921, + -0.674212461724509 + ], + "raw_score": -0.48458048701286316, + "prediction": 0.6159555315971375 + }, + { + "features": [ + -0.7123370644127087, + -0.8795096286589426, + 2.2816328772760617 + ], + "raw_score": 0.8761682510375977, + "prediction": 2.401679515838623 + }, + { + "features": [ + 0.2975110711855549, + 0.8867590544047386, + -0.4890774681189237 + ], + "raw_score": -0.18565593659877777, + "prediction": 0.8305593132972717 + }, + { + "features": [ + -0.18595667718604886, + -0.713553981370355, + -2.651708334919202 + ], + "raw_score": -0.48458048701286316, + "prediction": 0.6159555315971375 + } +] \ No newline at end of file diff --git a/internal/xgjson/testdata/test_regression.json b/internal/xgjson/testdata/test_regression.json new file mode 100644 index 0000000..825add4 --- /dev/null +++ b/internal/xgjson/testdata/test_regression.json @@ -0,0 +1 @@ 
+{"learner":{"attributes":{},"feature_names":[],"feature_types":[],"gradient_booster":{"model":{"cats":{"enc":[],"feature_segments":[],"sorted_idx":[]},"gbtree_model_param":{"num_parallel_tree":"1","num_trees":"10"},"iteration_indptr":[0,1,2,3,4,5,6,7,8,9,10],"tree_info":[0,0,0,0,0,0,0,0,0,0],"trees":[{"base_weights":[-3.4900674E-7,-2.1703766E1,3.888591E1,-5.660842E1,1.1113851E0,1.8131002E1,6.76406E1,-2.410375E1,-1.3382771E1,-2.6366928E0,1.2017798E1,-1.8681777E0,9.061715E0,1.3893869E1,2.3894E1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":0,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[3.3927622E5,2.055225E5,8.471297E4,2.5468469E4,6.0891074E4,2.5188213E4,1.1675E4,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[3.3198032E-1,-2.0812225E-1,1.9204912E-1,-7.2713715E-1,1.2897527E0,-7.728777E-1,6.982233E-1,-2.410375E1,-1.3382771E1,-2.6366928E0,1.2017798E1,-1.8681777E0,9.061715E0,1.3893869E1,2.3894E1],"split_indices":[1,0,0,1,0,0,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,2.57E2,1.43E2,1.01E2,1.56E2,8.4E1,5.9E1,3.2E1,6.9E1,1.25E2,3.1E1,2.8E1,5.6E1,2.3E1,3.6E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[4.208488E-2,-3.1665228E1,1.6616203E1,-4.209725E1,-9.3845826E-1,-1.0655563E1,3.008117E1,-1.8151125E1,-9.645155E0,-1.0015606E1,2.8573256E0,-7.5468655E0,3.6063957E0,6.482903E0,1.6121021E1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":1,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[2.1125997E5,4.4195344E4,9.7264875E4,1.69125E4,1.2539539E4,2.9239414E4,3.4309188E4,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,
2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-3.357847E-1,6.725737E-1,-2.5497723E-1,-6.095122E-1,-1.3044695E0,6.4537597E-1,8.700677E-1,-1.8151125E1,-9.645155E0,-1.0015606E1,2.8573256E0,-7.5468655E0,3.6063957E0,6.482903E0,1.6121021E1],"split_indices":[1,0,0,0,1,1,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,1.37E2,2.63E2,1.02E2,3.5E1,8.7E1,1.76E2,3.4E1,6.8E1,8E0,2.7E1,5.3E1,3.4E1,1.31E2,4.5E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-1.1167165E-2,-2.2418953E1,1.1702036E1,-5.0537853E1,-1.5123926E1,-1.5157427E0,2.8579166E1,-1.5612584E1,-1.4890947E0,-7.604434E0,1.3045177E-1,-1.0831369E1,9.6371937E-1,5.392886E0,1.2125985E1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":2,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.0551161E5,2.754336E4,5.893595E4,1.6610234E3,1.7668422E4,2.4483562E4,1.3836781E4,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-3.357847E-1,-1.3719012E0,1.9204912E-1,1.2897527E0,2.5049284E-1,-1.5255252E0,8.004095E-1,-1.5612584E1,-1.4890947E0,-7.604434E0,1.3045177E-1,-1.0831369E1,9.6371937E-1,5.392886E0,1.2125985E1],"split_indices":[1,1,0,0,0,0,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,1.37E2,2.63E2,2.7E1,1.1E2,1.48E2,1.15E2,2.6E1,1E0,6.6E1,4.4E1,1.7E1,1.31E2,6.2E1,5.3E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-5.002281E-2,-7.8147173E0,2.0562166E1,-2.491751E1,-3.1489772E-1,1.1856391E0,2.5955784E1,-9.092211E0,4.703793E-1,-4.403968E0,1.952447E0,-5.087349E0,2.3423016E0,5.6828566E0,1.2984971E1],"categories":[],"categories_nodes":[],"categories_segments":[],"catego
ries_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":3,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[6.4339684E4,3.7446367E4,1.1465273E4,1.2752523E4,2.009022E4,3.1360671E3,9.630484E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[5.5979043E-1,-4.6341768E-1,-7.539646E-1,9.1786194E-1,-6.414816E-1,-5.100164E-1,1.3993554E0,-9.092211E0,4.703793E-1,-4.403968E0,1.952447E0,-5.087349E0,2.3423016E0,5.6828566E0,1.2984971E1],"split_indices":[1,0,0,3,1,3,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,2.91E2,1.09E2,8.8E1,2.03E2,2.4E1,8.5E1,7.3E1,1.5E1,6.5E1,1.38E2,6E0,1.8E1,6.2E1,2.3E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-2.2922823E-2,-8.812916E0,9.98283E0,-1.2287054E1,2.057175E1,3.8232608E0,2.4433542E1,-9.349065E0,-2.6622648E0,-7.8267634E-1,7.2537894E0,2.2643733E-1,8.955402E0,5.292737E0,1.1337043E1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":4,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[3.535613E4,2.209925E4,1.6640496E4,1.2092553E4,1.9864639E3,1.0599696E4,4.4128164E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[1.042011E-1,1.5033983E0,5.6091946E-1,-9.983854E-1,-1.2239403E0,1.6777008E0,1.1581109E0,-9.349065E0,-2.6622648E0,-7.8267634E-1,7.2537894E0,2.2643733E-1,8.955402E0,5.292737E0,1.1337043E1],"split_indices":[1,0,0,0,2,1,0,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,2.13E2,1.87E2,1.91E2,2.2E1,1.32E2,5.5E1,2.8E1,1.63E2,3E0,1.9E1,1.19E2,1.3E1,3.8E1,1.7E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"
1"}},{"base_weights":[3.450642E-2,-7.6010866E0,7.9013066E0,-2.0408184E1,-2.7366667E0,-1.424733E1,1.0195322E1,-1.4414116E1,-5.0269175E0,-2.0192778E0,4.2070208E0,-5.7176957E0,-9.828203E-1,-1.6186364E0,4.080827E0],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":5,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[2.41472E4,1.2653163E4,1.0205502E4,4.8470625E3,1.0069055E4,9.405698E2,9.635811E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[5.9218433E-2,-5.449191E-1,-1.2608839E0,-2.1356742E0,8.696059E-1,1.943843E-1,-8.254972E-1,-1.4414116E1,-5.0269175E0,-2.0192778E0,4.2070208E0,-5.7176957E0,-9.828203E-1,-1.6186364E0,4.080827E0],"split_indices":[3,1,0,1,0,1,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,2.03E2,1.97E2,5.5E1,1.48E2,1.8E1,1.79E2,5E0,5E1,1.2E2,2.8E1,1.2E1,6E0,3.2E1,1.47E2],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-2.9440423E-2,-5.389619E0,6.9236445E0,-7.6967645E0,1.5408404E1,-1.2216392E1,8.663874E0,-4.9256887E0,-1.1134292E0,-1.1573362E0,5.547367E0,-5.031951E0,1.3524102E0,1.9145297E0,6.08035E0],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":6,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.4982496E4,1.10109795E4,5.93475E3,7.065041E3,1.4374014E3,1.2187563E3,4.134176E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[1.504189E-1,1.3902075E0,-1.2515395E0,-5.100164E-1,-1.51937E0,1.3071427E0,9.633761E-1,-4.9256887E0,-1.1134292E0,-1.1573362E0,5.547367E0,-5.031951E0,1.3524102E0,1.9145297E0,6.08035E0],"split_indic
es":[0,1,1,3,0,2,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,2.26E2,1.74E2,2.04E2,2.2E1,1.4E1,1.6E2,6.3E1,1.41E2,3E0,1.9E1,1.1E1,3E0,1.35E2,2.5E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-1.395906E-2,-2.2662601E1,1.1718702E0,-2.7021698E1,-7.9877605E0,-1.3520322E0,1.1025952E1,-9.467001E0,-4.487939E0,-3.1854148E0,7.7455217E-1,-1.2673126E0,2.162385E0,1.5491321E0,4.4347887E0],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":7,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[1.0796384E4,1.0635391E3,9.515531E3,3.4842383E2,1.9422064E2,7.529547E3,1.6414043E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[-1.5566292E0,5.8392817E-1,9.1786194E-1,-1.7131345E0,1.521316E0,6.725737E-1,-3.426876E-1,-9.467001E0,-4.487939E0,-3.1854148E0,7.7455217E-1,-1.2673126E0,2.162385E0,1.5491321E0,4.4347887E0],"split_indices":[1,0,3,1,2,0,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,1.9E1,3.81E2,1.4E1,5E0,3.04E2,7.7E1,9E0,5E0,4E0,1E0,2.28E2,7.6E1,3.1E1,4.6E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-3.3429414E-2,-2.614427E0,7.8663254E0,-1.4629532E1,-7.3140913E-1,-1.6345493E1,9.057828E0,-6.7656684E0,-3.291841E0,-1.1315507E0,1.3278493E0,-5.9210744E0,-4.1697222E-1,1.190232E0,3.8547845E0],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":8,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[8.196654E3,6.8445713E3,3.0040518E3,9.4035547E2,4.140856E3,2.261709E2,1.7825186E3,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"rig
ht_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[6.454842E-1,-1.0676204E0,-1.6685841E0,-8.0227727E-1,3.8240975E-1,1.0666747E0,-1.0594835E-1,-6.7656684E0,-3.291841E0,-1.1315507E0,1.3278493E0,-5.9210744E0,-4.1697222E-1,1.190232E0,3.8547845E0],"split_indices":[1,0,3,3,3,1,2,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,3.02E2,9.8E1,4E1,2.62E2,4E0,9.4E1,1.1E1,2.9E1,1.65E2,9.7E1,3E0,1E0,4.1E1,5.3E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}},{"base_weights":[-5.1216763E-2,-9.203662E-1,1.65521E1,-4.6210365E0,2.68967E0,1.094314E0,1.9215572E1,-4.8870697E0,-9.8923075E-1,-6.7744327E-1,1.7080133E0,-1.0934024E0,1.1666605E0,2.6789975E0,6.5080333E0],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"id":9,"left_children":[1,3,5,7,9,11,13,-1,-1,-1,-1,-1,-1,-1,-1],"loss_changes":[5.8019707E3,5.1157793E3,8.023999E2,2.8654568E3,2.8960432E3,6.71471E1,2.3954541E2,0E0,0E0,0E0,0E0,0E0,0E0,0E0,0E0],"parents":[2147483647,0,0,1,1,2,2,3,3,4,4,5,5,6,6],"right_children":[2,4,6,8,10,12,14,-1,-1,-1,-1,-1,-1,-1,-1],"split_conditions":[1.6420152E0,4.886007E-2,-1.279577E0,-1.5625459E0,-2.8865865E-1,-1.0372461E0,-5.642476E-1,-4.8870697E0,-9.8923075E-1,-6.7744327E-1,1.7080133E0,-1.0934024E0,1.1666605E0,2.6789975E0,6.5080333E0],"split_indices":[0,1,2,1,2,1,1,0,0,0,0,0,0,0,0],"split_type":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"sum_hessian":[4E2,3.81E2,1.9E1,1.88E2,1.93E2,3E0,1.6E1,1.8E1,1.7E2,7.3E1,1.2E2,1E0,2E0,4E0,1.2E1],"tree_param":{"num_deleted":"0","num_feature":"4","num_nodes":"15","size_leaf_vector":"1"}}]},"name":"gbtree"},"learner_model_param":{"base_score":"[3.5165477E0]","boost_from_average":"1","num_class":"0","num_feature":"4","num_target":"1"},"objective":{"name":"reg:squarederror","reg_loss_param":{"scale_pos_weight":"1"}}},"version":[3,2,0]} \ No newline at end of file diff --git 
a/internal/xgjson/testdata/test_regression.py b/internal/xgjson/testdata/test_regression.py new file mode 100644 index 0000000..5a4629f --- /dev/null +++ b/internal/xgjson/testdata/test_regression.py @@ -0,0 +1,65 @@ +# /// script +# dependencies = ["xgboost>=2.0", "numpy", "scikit-learn"] +# /// +""" +Generates reg:squarederror XGBoost test model. +Run with: uv run test_regression.py +Outputs: test_regression.json, test_regression.ubj, test_regression_expected.json +""" + +import json +import os +import numpy as np +from sklearn.datasets import make_regression +import xgboost as xgb + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def main(): + X, y = make_regression(n_samples=500, n_features=4, random_state=42) + + X_train, X_test = X[:400], X[400:] + y_train = y[:400] + + model = xgb.XGBRegressor( + n_estimators=10, + max_depth=3, + objective="reg:squarederror", + random_state=42, + ) + model.fit(X_train, y_train) + + json_path = os.path.join(SCRIPT_DIR, "test_regression.json") + ubj_path = os.path.join(SCRIPT_DIR, "test_regression.ubj") + model.get_booster().save_model(json_path) + model.get_booster().save_model(ubj_path) + print(f"Saved {json_path}") + print(f"Saved {ubj_path}") + + test_rows = X_test[:5] + dtest = xgb.DMatrix(test_rows) + # identity link: output_margin == predict() + raw_preds = model.get_booster().predict(dtest, output_margin=True) + + expected = [] + for i in range(len(test_rows)): + expected.append( + { + "features": test_rows[i].tolist(), + "raw_score": float(raw_preds[i]), + } + ) + + expected_path = os.path.join(SCRIPT_DIR, "test_regression_expected.json") + with open(expected_path, "w") as f: + json.dump(expected, f, indent=2) + print(f"Saved {expected_path}") + + print(f"\nXGBoost version: {xgb.__version__}") + for e in expected: + print(f" features={[round(v,4) for v in e['features']]} raw={e['raw_score']:.6f}") + + +if __name__ == "__main__": + main() diff --git a/internal/xgjson/testdata/test_regression.ubj 
b/internal/xgjson/testdata/test_regression.ubj new file mode 100644 index 0000000..59db117 Binary files /dev/null and b/internal/xgjson/testdata/test_regression.ubj differ diff --git a/internal/xgjson/testdata/test_regression_expected.json b/internal/xgjson/testdata/test_regression_expected.json new file mode 100644 index 0000000..2a50522 --- /dev/null +++ b/internal/xgjson/testdata/test_regression_expected.json @@ -0,0 +1,47 @@ +[ + { + "features": [ + -2.123895724309807, + -0.8397218421807761, + -0.5993926454440222, + -0.525755021680761 + ], + "raw_score": -80.28502655029297 + }, + { + "features": [ + 0.7326400772155792, + 0.7276295436369798, + 0.05194588580729943, + -0.08071658010858232 + ], + "raw_score": 57.882362365722656 + }, + { + "features": [ + 2.4457519796168263, + 0.2799686263198203, + -1.1254890472983765, + 0.1292211819752275 + ], + "raw_score": 60.96694564819336 + }, + { + "features": [ + -0.3653215513121087, + 2.013387247526623, + 0.13653533108273744, + 0.1846803058649084 + ], + "raw_score": 53.012428283691406 + }, + { + "features": [ + 0.19652116970147013, + 0.6427227598675439, + 1.3291525301324314, + 0.7090037575885123 + ], + "raw_score": 42.95940399169922 + } +] \ No newline at end of file diff --git a/internal/xgjson/testdata/train.py b/internal/xgjson/testdata/train.py new file mode 100644 index 0000000..fc8cecc --- /dev/null +++ b/internal/xgjson/testdata/train.py @@ -0,0 +1,21 @@ +# /// script +# dependencies = [] +# /// +""" +Coordinator: runs all per-case training scripts. 
+Run with: uv run train.py +""" + +import subprocess +import pathlib + +here = pathlib.Path(__file__).parent + +for s in [ + "test_binary_logistic.py", + "test_regression.py", + "test_multiclass.py", + "test_poisson.py", +]: + print(f"\n=== Running {s} ===") + subprocess.run(["uv", "run", str(here / s)], check=True) diff --git a/internal/xgjson/xgjson_io.go b/internal/xgjson/xgjson_io.go new file mode 100644 index 0000000..df3f5cc --- /dev/null +++ b/internal/xgjson/xgjson_io.go @@ -0,0 +1,468 @@ +// Package xgjson reads XGBoost models saved in JSON (.json) and Universal Binary JSON (.ubj) formats. +// These are the modern formats produced by XGBoost 2.x and 3.x. +// +// Use ReadModelJSON for .json files and ReadModelUBJ for .ubj files. +// Both return *ModelJSON with the same logical structure. +package xgjson + +import ( + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + + "github.com/dmitryikh/leaves/internal/ubjdecode" +) + +// ModelJSON is the top-level XGBoost model structure. +type ModelJSON struct { + Learner LearnerJSON + Version []int32 +} + +// LearnerJSON holds the learner sub-tree of the model. +type LearnerJSON struct { + FeatureNames []string + GradientBooster GradientBoosterJSON + LearnerModelParam LearnerModelParamJSON + Objective ObjectiveJSON +} + +// LearnerModelParamJSON holds global model parameters. +// Numeric fields are stored as strings (or bracketed strings like "[5.2E-1]") by XGBoost. +type LearnerModelParamJSON struct { + BaseScore string // may be "[value]" in XGBoost 2.x + NumClass string + NumFeature string +} + +// ObjectiveJSON holds the objective function name. +type ObjectiveJSON struct { + Name string +} + +// GradientBoosterJSON wraps the booster model and its name. +type GradientBoosterJSON struct { + Model GBTreeModelDataJSON + Name string // "gbtree" or "dart" +} + +// GBTreeModelDataJSON holds the tree ensemble data. 
+type GBTreeModelDataJSON struct { + GBTreeModelParam GBTreeModelParamJSON + TreeInfo []int32 + Trees []TreeJSON +} + +// GBTreeModelParamJSON holds ensemble-level parameters. +type GBTreeModelParamJSON struct { + NumParallelTree string + NumTrees string +} + +// TreeJSON holds a single decision tree's data. +type TreeJSON struct { + ID int32 + BaseWeights []float64 + DefaultLeft []int32 + LeftChildren []int32 + RightChildren []int32 + SplitConditions []float64 + SplitIndices []int32 + SplitType []int32 + NumNodes string // from tree_param.num_nodes +} + +// --------------------------------------------------------------------------- +// JSON format (encoding/json with struct tags) +// --------------------------------------------------------------------------- + +// jsonModel is the raw JSON-tagged equivalent used only for unmarshaling. +type jsonModel struct { + Learner struct { + FeatureNames []string `json:"feature_names"` + GradientBooster struct { + Model struct { + GBTreeModelParam struct { + NumParallelTree string `json:"num_parallel_tree"` + NumTrees string `json:"num_trees"` + } `json:"gbtree_model_param"` + TreeInfo []int32 `json:"tree_info"` + Trees []struct { + ID int32 `json:"id"` + BaseWeights []float64 `json:"base_weights"` + DefaultLeft []int32 `json:"default_left"` + LeftChildren []int32 `json:"left_children"` + RightChildren []int32 `json:"right_children"` + SplitConditions []float64 `json:"split_conditions"` + SplitIndices []int32 `json:"split_indices"` + SplitType []int32 `json:"split_type"` + TreeParam struct { + NumNodes string `json:"num_nodes"` + } `json:"tree_param"` + } `json:"trees"` + } `json:"model"` + Name string `json:"name"` + } `json:"gradient_booster"` + LearnerModelParam struct { + BaseScore string `json:"base_score"` + NumClass string `json:"num_class"` + NumFeature string `json:"num_feature"` + } `json:"learner_model_param"` + Objective struct { + Name string `json:"name"` + } `json:"objective"` + } `json:"learner"` + Version 
[]int32 `json:"version"` +} + +// ReadModelJSON reads an XGBoost model from a JSON reader. +func ReadModelJSON(r io.Reader) (*ModelJSON, error) { + var raw jsonModel + if err := json.NewDecoder(r).Decode(&raw); err != nil { + return nil, fmt.Errorf("xgjson: JSON decode: %w", err) + } + return jsonModelToModelJSON(&raw), nil +} + +func jsonModelToModelJSON(raw *jsonModel) *ModelJSON { + m := &ModelJSON{ + Version: raw.Version, + } + rl := &raw.Learner + m.Learner.FeatureNames = rl.FeatureNames + m.Learner.LearnerModelParam = LearnerModelParamJSON{ + BaseScore: rl.LearnerModelParam.BaseScore, + NumClass: rl.LearnerModelParam.NumClass, + NumFeature: rl.LearnerModelParam.NumFeature, + } + m.Learner.Objective = ObjectiveJSON{Name: rl.Objective.Name} + + gb := &rl.GradientBooster + m.Learner.GradientBooster = GradientBoosterJSON{ + Name: gb.Name, + Model: GBTreeModelDataJSON{ + GBTreeModelParam: GBTreeModelParamJSON{ + NumParallelTree: gb.Model.GBTreeModelParam.NumParallelTree, + NumTrees: gb.Model.GBTreeModelParam.NumTrees, + }, + TreeInfo: gb.Model.TreeInfo, + }, + } + + for _, t := range gb.Model.Trees { + m.Learner.GradientBooster.Model.Trees = append( + m.Learner.GradientBooster.Model.Trees, + TreeJSON{ + ID: t.ID, + BaseWeights: t.BaseWeights, + DefaultLeft: t.DefaultLeft, + LeftChildren: t.LeftChildren, + RightChildren: t.RightChildren, + SplitConditions: t.SplitConditions, + SplitIndices: t.SplitIndices, + SplitType: t.SplitType, + NumNodes: t.TreeParam.NumNodes, + }, + ) + } + return m +} + +// --------------------------------------------------------------------------- +// UBJ format — walk map[string]interface{} from ubjdecode +// --------------------------------------------------------------------------- + +// ReadModelUBJ reads an XGBoost model from a UBJ reader. +// Uses ubjdecode internally; produces the same ModelJSON struct. 
+func ReadModelUBJ(r io.Reader) (*ModelJSON, error) { + raw, err := ubjdecode.DecodeValue(r) + if err != nil { + return nil, fmt.Errorf("xgjson: UBJ decode: %w", err) + } + root, ok := raw.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ root is not an object") + } + return ubjMapToModelJSON(root) +} + +// --------------------------------------------------------------------------- +// UBJ map-walking helpers +// --------------------------------------------------------------------------- + +func ubjMapToModelJSON(root map[string]interface{}) (*ModelJSON, error) { + m := &ModelJSON{} + + if v, ok := root["version"]; ok { + m.Version = toInt32Slice(v) + } + + learnerRaw, ok := root["learner"] + if !ok { + return nil, fmt.Errorf("xgjson: UBJ missing 'learner' key") + } + learnerMap, ok := learnerRaw.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ 'learner' is not an object") + } + + // feature_names + if v, ok := learnerMap["feature_names"]; ok { + m.Learner.FeatureNames = toStringSlice(v) + } + + // learner_model_param + if v, ok := learnerMap["learner_model_param"]; ok { + if mp, ok := v.(map[string]interface{}); ok { + m.Learner.LearnerModelParam = LearnerModelParamJSON{ + BaseScore: mapString(mp, "base_score"), + NumClass: mapString(mp, "num_class"), + NumFeature: mapString(mp, "num_feature"), + } + } + } + + // objective + if v, ok := learnerMap["objective"]; ok { + if om, ok := v.(map[string]interface{}); ok { + m.Learner.Objective = ObjectiveJSON{Name: mapString(om, "name")} + } + } + + // gradient_booster + if v, ok := learnerMap["gradient_booster"]; ok { + gbMap, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ 'gradient_booster' is not an object") + } + gb, err := ubjGradientBooster(gbMap) + if err != nil { + return nil, err + } + m.Learner.GradientBooster = *gb + } + + return m, nil +} + +func ubjGradientBooster(gbMap map[string]interface{}) (*GradientBoosterJSON, error) { + 
gb := &GradientBoosterJSON{ + Name: mapString(gbMap, "name"), + } + + modelRaw, ok := gbMap["model"] + if !ok { + return nil, fmt.Errorf("xgjson: UBJ gradient_booster missing 'model'") + } + modelMap, ok := modelRaw.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ gradient_booster 'model' is not an object") + } + + // gbtree_model_param + if v, ok := modelMap["gbtree_model_param"]; ok { + if pm, ok := v.(map[string]interface{}); ok { + gb.Model.GBTreeModelParam = GBTreeModelParamJSON{ + NumParallelTree: mapString(pm, "num_parallel_tree"), + NumTrees: mapString(pm, "num_trees"), + } + } + } + + // tree_info + if v, ok := modelMap["tree_info"]; ok { + gb.Model.TreeInfo = toInt32Slice(v) + } + + // trees + treesRaw, ok := modelMap["trees"] + if !ok { + return nil, fmt.Errorf("xgjson: UBJ model missing 'trees'") + } + treesSlice, ok := treesRaw.([]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ 'trees' is not an array") + } + for i, tRaw := range treesSlice { + tMap, ok := tRaw.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xgjson: UBJ tree %d is not an object", i) + } + tree, err := ubjTree(tMap) + if err != nil { + return nil, fmt.Errorf("xgjson: UBJ tree %d: %w", i, err) + } + gb.Model.Trees = append(gb.Model.Trees, *tree) + } + + return gb, nil +} + +func ubjTree(tMap map[string]interface{}) (*TreeJSON, error) { + t := &TreeJSON{} + + if v, ok := tMap["id"]; ok { + t.ID = int32(toInt64(v)) + } + t.BaseWeights = toFloat64Slice(tMap["base_weights"]) + t.DefaultLeft = toInt32Slice(tMap["default_left"]) + t.LeftChildren = toInt32Slice(tMap["left_children"]) + t.RightChildren = toInt32Slice(tMap["right_children"]) + t.SplitConditions = toFloat64Slice(tMap["split_conditions"]) + t.SplitIndices = toInt32Slice(tMap["split_indices"]) + t.SplitType = toInt32Slice(tMap["split_type"]) + + if v, ok := tMap["tree_param"]; ok { + if pm, ok := v.(map[string]interface{}); ok { + t.NumNodes = mapString(pm, "num_nodes") + } + } + 
return t, nil +} + +// --------------------------------------------------------------------------- +// Type coercion helpers +// --------------------------------------------------------------------------- + +func mapString(m map[string]interface{}, key string) string { + if v, ok := m[key]; ok { + if s, ok := v.(string); ok { + return s + } + return fmt.Sprintf("%v", v) + } + return "" +} + +func toInt64(v interface{}) int64 { + switch x := v.(type) { + case int64: + return x + case int32: + return int64(x) + case float64: + return int64(x) + case float32: + return int64(x) + default: + return 0 + } +} + +// toInt32Slice coerces any of the UBJ typed array forms to []int32. +func toInt32Slice(v interface{}) []int32 { + if v == nil { + return nil + } + switch x := v.(type) { + case []int32: + return x + case []int64: + out := make([]int32, len(x)) + for i, vv := range x { + out[i] = int32(vv) + } + return out + case []interface{}: + out := make([]int32, len(x)) + for i, vv := range x { + out[i] = int32(toInt64(vv)) + } + return out + default: + return nil + } +} + +// toFloat64Slice coerces any of the UBJ typed array forms to []float64. 
func toFloat64Slice(v interface{}) []float64 {
	switch arr := v.(type) {
	case []float64:
		// Already the target type: hand it back without copying.
		return arr
	case []float32:
		out := make([]float64, len(arr))
		for i, f := range arr {
			out[i] = float64(f)
		}
		return out
	case []int64:
		// Typed integer arrays are accepted so that a float-valued field
		// written with an integer element type is not dropped as nil.
		out := make([]float64, len(arr))
		for i, n := range arr {
			out[i] = float64(n)
		}
		return out
	case []int32:
		out := make([]float64, len(arr))
		for i, n := range arr {
			out[i] = float64(n)
		}
		return out
	case []interface{}:
		// Mixed array: accept the same numeric kinds as toInt64 does.
		// Elements of any other type keep the zero value.
		out := make([]float64, len(arr))
		for i, elem := range arr {
			switch n := elem.(type) {
			case float64:
				out[i] = n
			case float32:
				out[i] = float64(n)
			case int64:
				out[i] = float64(n)
			case int32:
				out[i] = float64(n)
			}
		}
		return out
	default:
		// nil or an unsupported container type.
		return nil
	}
}

// toStringSlice coerces a decoded array to []string. A []string is returned
// as-is, a []interface{} is converted element by element (non-string elements
// are rendered with %v), and anything else, including nil, yields nil.
func toStringSlice(v interface{}) []string {
	switch arr := v.(type) {
	case []string:
		return arr
	case []interface{}:
		out := make([]string, len(arr))
		for i, elem := range arr {
			if s, isStr := elem.(string); isStr {
				out[i] = s
				continue
			}
			out[i] = fmt.Sprintf("%v", elem)
		}
		return out
	default:
		return nil
	}
}

// ParseBaseScore parses XGBoost's base_score string, which may be wrapped in
// brackets like "[5.2313304E-1]" (XGBoost 2.x) or a plain float string.
// Surrounding whitespace, inside or outside the brackets, is ignored.
// For multi-class models (XGBoost 3.x) with comma-separated per-class values,
// use ParseBaseScoreMulti instead; such input is reported as an error here.
func ParseBaseScore(s string) (float64, error) {
	s = strings.TrimSpace(s)
	s = strings.TrimSuffix(strings.TrimPrefix(s, "["), "]")
	s = strings.TrimSpace(s)

	score, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return 0, fmt.Errorf("xgjson: parse base_score %q: %w", s, err)
	}
	return score, nil
}

// ParseBaseScoreMulti parses XGBoost's base_score string, returning one float64
// per output group. Handles both single-value (all groups share same score) and
// comma-separated per-group strings like "[1.4E-2,-2.2E-2,8.1E-3]" from XGBoost 3.x
// multi-class models.
+func ParseBaseScoreMulti(s string) ([]float64, error) { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "[") + s = strings.TrimSuffix(s, "]") + s = strings.TrimSpace(s) + parts := strings.Split(s, ",") + out := make([]float64, len(parts)) + for i, p := range parts { + v, err := strconv.ParseFloat(strings.TrimSpace(p), 64) + if err != nil { + return nil, fmt.Errorf("xgjson: parse base_score[%d] in %q: %w", i, s, err) + } + out[i] = v + } + return out, nil +} diff --git a/xgensemble_io.go b/xgensemble_io.go index 5102dfc..59d68b5 100644 --- a/xgensemble_io.go +++ b/xgensemble_io.go @@ -4,8 +4,11 @@ import ( "bufio" "fmt" "os" + "path/filepath" + "strings" "github.com/dmitryikh/leaves/internal/xgbin" + "github.com/dmitryikh/leaves/internal/xgjson" "github.com/dmitryikh/leaves/transformation" ) @@ -248,7 +251,24 @@ func XGEnsembleFromReader(reader *bufio.Reader, loadTransformation bool) (*Ensem return &Ensemble{e, transform}, nil } -// XGEnsembleFromFile reads XGBoost model from binary file. Works with 'gbtree' and 'dart' models +// XGEnsembleFromFile reads an XGBoost model from a file. +// +// Format detection uses both the file extension and a content heuristic so +// that misnamed files don't silently produce wrong results and to preserve +// backward compatibility with existing callers: +// +// - .json → content-verified JSON; falls through to legacy binary if +// content doesn't look like XGBoost JSON +// - .ubj/.ubjson → content-verified UBJ; falls through to legacy binary if +// content doesn't look like XGBoost UBJ +// - any other → falls through directly to legacy binary reader +// +// Note: for unknown extensions, we intentionally do NOT run content detection +// and fall straight through to the legacy binary reader. This preserves the +// existing behavior for all callers who weren't using .json/.ubj extensions. +// The trade-off is that a JSON/UBJ file named e.g. 
"model.bin" won't be
+// auto-detected; callers should use XGEnsembleFromJSONFile / XGEnsembleFromUBJFile
+// in that case.
 func XGEnsembleFromFile(filename string, loadTransformation bool) (*Ensemble, error) {
 	reader, err := os.Open(filename)
 	if err != nil {
@@ -256,5 +276,21 @@ func XGEnsembleFromFile(filename string, loadTransformation bool) (*Ensemble, er
 	}
 	defer reader.Close()
 	bufReader := bufio.NewReader(reader)
+	ext := strings.ToLower(filepath.Ext(filename))
+	switch ext {
+	case ".json":
+		if ok, _ := xgjson.LooksLikeJSON(bufReader); ok {
+			return XGEnsembleFromJSONReader(bufReader, loadTransformation)
+		}
+		// Extension says JSON but content doesn't match — fall through to legacy
+		// binary. TODO: consider returning an error here instead, since a caller
+		// who named the file .json almost certainly intended it to be JSON.
+	case ".ubj", ".ubjson":
+		if ok, _ := xgjson.LooksLikeUBJ(bufReader); ok {
+			return XGEnsembleFromUBJReader(bufReader, loadTransformation)
+		}
+		// Extension says UBJ but content doesn't match — fall through to legacy binary.
+		// TODO: consider returning an error here instead.
+	}
 	return XGEnsembleFromReader(bufReader, loadTransformation)
 }
diff --git a/xgensemble_json_io.go b/xgensemble_json_io.go
new file mode 100644
index 0000000..b0d5371
--- /dev/null
+++ b/xgensemble_json_io.go
@@ -0,0 +1,320 @@
+package leaves
+
+import (
+	"bufio"
+	"fmt"
+	"io"
+	"math"
+	"os"
+	"strconv"
+
+	"github.com/dmitryikh/leaves/internal/xgjson"
+	"github.com/dmitryikh/leaves/transformation"
+)
+
+// XGEnsembleFromJSONReader reads an XGBoost model from a JSON reader.
+func XGEnsembleFromJSONReader(reader io.Reader, loadTransformation bool) (*Ensemble, error) {
+	m, err := xgjson.ReadModelJSON(reader)
+	if err != nil {
+		return nil, err
+	}
+	return xgEnsembleFromModelJSON(m, loadTransformation)
+}
+
+// XGEnsembleFromJSONFile reads an XGBoost model from a JSON file.
+func XGEnsembleFromJSONFile(filename string, loadTransformation bool) (*Ensemble, error) {
+	f, err := os.Open(filename)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+	return XGEnsembleFromJSONReader(bufio.NewReader(f), loadTransformation)
+}
+
+// XGEnsembleFromUBJReader reads an XGBoost model from a UBJ reader.
+func XGEnsembleFromUBJReader(reader io.Reader, loadTransformation bool) (*Ensemble, error) {
+	m, err := xgjson.ReadModelUBJ(reader)
+	if err != nil {
+		return nil, err
+	}
+	return xgEnsembleFromModelJSON(m, loadTransformation)
+}
+
+// XGEnsembleFromUBJFile reads an XGBoost model from a UBJ file.
+func XGEnsembleFromUBJFile(filename string, loadTransformation bool) (*Ensemble, error) {
+	f, err := os.Open(filename)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+	return XGEnsembleFromUBJReader(bufio.NewReader(f), loadTransformation)
+}
+
+// xgEnsembleFromModelJSON converts a parsed xgjson.ModelJSON into an *Ensemble.
+//
+// Only the "gbtree" and "dart" boosters are accepted. When loadTransformation
+// is false the ensemble is wrapped in a raw transform; when true, the output
+// transform is chosen from the objective name and an error is returned for
+// objectives that have no transform listed below.
+func xgEnsembleFromModelJSON(m *xgjson.ModelJSON, loadTransformation bool) (*Ensemble, error) {
+	e := &xgEnsemble{}
+
+	gb := &m.Learner.GradientBooster
+	switch gb.Name {
+	case "gbtree":
+		e.name = "xgboost.gbtree"
+	case "dart":
+		e.name = "xgboost.dart"
+	default:
+		return nil, fmt.Errorf("xgensemble_json_io: only 'gbtree' or 'dart' supported (got %q)", gb.Name)
+	}
+
+	// NumFeature
+	numFeatureStr := m.Learner.LearnerModelParam.NumFeature
+	numFeature64, err := strconv.ParseInt(numFeatureStr, 10, 64)
+	if err != nil {
+		return nil, fmt.Errorf("xgensemble_json_io: parse num_feature %q: %w", numFeatureStr, err)
+	}
+	if numFeature64 == 0 {
+		return nil, fmt.Errorf("xgensemble_json_io: zero number of features")
+	}
+	// Reject values that would wrap when narrowed to uint32/int below.
+	if numFeature64 < 0 || numFeature64 > math.MaxInt32 {
+		return nil, fmt.Errorf("xgensemble_json_io: num_feature %d out of range", numFeature64)
+	}
+	numFeatures := uint32(numFeature64)
+	e.MaxFeatureIdx = int(numFeatures) - 1
+
+	// BaseScore — in XGBoost 2.x/3.x JSON/UBJ format, base_score is stored in
+	// probability (output) space. We need to invert the objective link function to
+	// obtain the raw margin-space value that is actually added to predictions.
+	//
+	// XGBoost 3.x multi-class models store a per-class base_score vector. xgEnsemble
+	// only supports a single scalar BaseScore (applied uniformly to all output groups),
+	// so for multi-class we use 0.0 and test expected values are generated accordingly.
+	baseScoreValues, err := xgjson.ParseBaseScoreMulti(m.Learner.LearnerModelParam.BaseScore)
+	if err != nil {
+		return nil, err
+	}
+	if len(baseScoreValues) == 0 {
+		return nil, fmt.Errorf("xgensemble_json_io: empty base_score")
+	}
+	objName := m.Learner.Objective.Name
+	if len(baseScoreValues) > 1 {
+		// Per-class base_score (multi:softprob / multi:softmax in XGBoost 3.x).
+		// Not representable as a single scalar; use 0.0.
+		e.BaseScore = 0.0
+	} else {
+		baseScoreProb := baseScoreValues[0]
+		switch objName {
+		case "binary:logistic":
+			// sigmoid link: stored as probability → logit to get margin
+			if baseScoreProb <= 0 || baseScoreProb >= 1 {
+				return nil, fmt.Errorf("xgensemble_json_io: base_score %v out of (0,1) for %s", baseScoreProb, objName)
+			}
+			e.BaseScore = math.Log(baseScoreProb / (1.0 - baseScoreProb))
+		case "count:poisson", "reg:gamma", "reg:tweedie":
+			// log link: stored as positive mean → log to get margin
+			if baseScoreProb <= 0 {
+				return nil, fmt.Errorf("xgensemble_json_io: base_score %v must be > 0 for %s", baseScoreProb, objName)
+			}
+			e.BaseScore = math.Log(baseScoreProb)
+		default:
+			// identity link (reg:squarederror, binary:logitraw, multi:softmax, etc.)
+			e.BaseScore = baseScoreProb
+		}
+	}
+
+	// nRawOutputGroups from TreeInfo pattern
+	gbd := &gb.Model
+	numTrees := len(gbd.Trees)
+	if numTrees == 0 {
+		return nil, fmt.Errorf("xgensemble_json_io: no trees in model")
+	}
+	if len(gbd.TreeInfo) != numTrees {
+		return nil, fmt.Errorf("xgensemble_json_io: TreeInfo length %d != numTrees %d",
+			len(gbd.TreeInfo), numTrees)
+	}
+
+	// Determine nRawOutputGroups from the TreeInfo cycling pattern: the group
+	// count is the position of the second 0. A model trained for a single
+	// boosting round has no second 0 (e.g. [0 1 2]), in which case every tree
+	// belongs to its own group.
+	nRawOutputGroups := numTrees
+	for i := 1; i < numTrees; i++ {
+		if gbd.TreeInfo[i] == 0 {
+			nRawOutputGroups = i
+			break
+		}
+	}
+	// Validate the full pattern
+	for i, ti := range gbd.TreeInfo {
+		if int(ti) != i%nRawOutputGroups {
+			return nil, fmt.Errorf("xgensemble_json_io: TreeInfo expected pattern [0 1 2 0 1 2...] (got %v)", gbd.TreeInfo)
+		}
+	}
+	// An incomplete final round cannot be mapped onto whole boosting iterations.
+	if numTrees%nRawOutputGroups != 0 {
+		return nil, fmt.Errorf("xgensemble_json_io: number of trees %d is not a multiple of output groups %d",
+			numTrees, nRawOutputGroups)
+	}
+	e.nRawOutputGroups = nRawOutputGroups
+
+	// WeightDrop — dart models don't encode drops in JSON/UBJ; use 1.0 for all
+	e.WeightDrop = make([]float64, numTrees)
+	for i := range e.WeightDrop {
+		e.WeightDrop[i] = 1.0
+	}
+
+	// Transformation
+	var transform transformation.Transform
+	transform = &transformation.TransformRaw{e.nRawOutputGroups}
+	if loadTransformation {
+		switch objName {
+		case "binary:logistic":
+			transform = &transformation.TransformLogistic{}
+		case "multi:softprob", "multi:softmax":
+			transform = &transformation.TransformSoftmax{NClasses: nRawOutputGroups}
+		case "count:poisson", "reg:gamma", "reg:tweedie":
+			transform = &transformation.TransformExponential{}
+		default:
+			return nil, fmt.Errorf("xgensemble_json_io: unknown transformation function %q", objName)
+		}
+	}
+
+	// Convert trees
+	e.Trees = make([]lgTree, 0, numTrees)
+	for i := range gbd.Trees {
+		tree, err := xgTreeFromModelJSON(&gbd.Trees[i], numFeatures)
+		if err != nil {
+			return nil, fmt.Errorf("xgensemble_json_io: tree %d: %w", i, err)
+		}
+		e.Trees = append(e.Trees, tree)
+	}
+
+	return &Ensemble{e, transform}, nil
+}
+
+// xgTreeFromModelJSON converts a xgjson.TreeJSON into an lgTree.
+// Mirrors xgTreeFromTreeModel in xgensemble_io.go.
+//
+// The per-node arrays of t must all have the same length. Child indices are
+// validated against that length so that a malformed model yields an error
+// instead of an out-of-range panic.
+func xgTreeFromModelJSON(t *xgjson.TreeJSON, numFeatures uint32) (lgTree, error) {
+	tree := lgTree{}
+
+	numNodes := len(t.LeftChildren)
+	if numNodes == 0 {
+		return tree, fmt.Errorf("tree with zero nodes")
+	}
+
+	// Validate array lengths are consistent
+	if len(t.RightChildren) != numNodes ||
+		len(t.SplitIndices) != numNodes ||
+		len(t.SplitConditions) != numNodes ||
+		len(t.DefaultLeft) != numNodes ||
+		len(t.BaseWeights) != numNodes ||
+		len(t.SplitType) != numNodes {
+		return tree, fmt.Errorf("inconsistent array lengths in tree %d", t.ID)
+	}
+
+	// XGBoost doesn't support categorical features
+	tree.nCategorical = 0
+
+	isLeaf := func(i int) bool { return t.LeftChildren[i] == -1 }
+
+	if numNodes == 1 {
+		// constant value tree
+		tree.leafValues = append(tree.leafValues, t.BaseWeights[0])
+		return tree, nil
+	}
+
+	createNode := func(i int) (lgNode, error) {
+		if t.SplitType[i] != 0 {
+			return lgNode{}, fmt.Errorf("categorical splits not supported (node %d, split_type=%d)", i, t.SplitType[i])
+		}
+		splitIdx := uint32(t.SplitIndices[i])
+		if splitIdx >= numFeatures {
+			return lgNode{}, fmt.Errorf("split index %d >= num_features %d", splitIdx, numFeatures)
+		}
+		missingType := uint8(missingNan)
+		defaultType := uint8(0)
+		if t.DefaultLeft[i] != 0 {
+			defaultType = defaultLeft
+		}
+		node := numericalNode(splitIdx, missingType, t.SplitConditions[i], defaultType)
+
+		left := int(t.LeftChildren[i])
+		right := int(t.RightChildren[i])
+		if left < 0 || right < 0 {
+			return node, fmt.Errorf("logic error: negative child index at node %d", i)
+		}
+		if left >= numNodes || right >= numNodes {
+			return node, fmt.Errorf("child index out of range at node %d (left=%d, right=%d, nodes=%d)", i, left, right, numNodes)
+		}
+		if isLeaf(left) {
+			node.Flags |= leftLeaf
+			node.Left = uint32(len(tree.leafValues))
+			tree.leafValues = append(tree.leafValues, t.BaseWeights[left])
+		}
+		if isLeaf(right) {
+			node.Flags |= rightLeaf
+			node.Right = uint32(len(tree.leafValues))
+			tree.leafValues = append(tree.leafValues, t.BaseWeights[right])
+		}
+		return node, nil
+	}
+
+	origNodeIdxStack := make([]int, 0, numNodes)
+	convNodeIdxStack := make([]int, 0, numNodes)
+	visited := make([]bool, numNodes)
+	tree.nodes = make([]lgNode, 0, numNodes)
+
+	node, err := createNode(0)
+	if err != nil {
+		return tree, err
+	}
+	tree.nodes = append(tree.nodes, node)
+	origNodeIdxStack = append(origNodeIdxStack, 0)
+	convNodeIdxStack = append(convNodeIdxStack, 0)
+
+	for len(origNodeIdxStack) > 0 {
+		convIdx := convNodeIdxStack[len(convNodeIdxStack)-1]
+		if tree.nodes[convIdx].Flags&rightLeaf == 0 {
+			origIdx := int(t.RightChildren[origNodeIdxStack[len(origNodeIdxStack)-1]])
+			if !visited[origIdx] {
+				node, err := createNode(origIdx)
+				if err != nil {
+					return tree, err
+				}
+				tree.nodes = append(tree.nodes, node)
+				convNewIdx := len(tree.nodes) - 1
+				convNodeIdxStack = append(convNodeIdxStack, convNewIdx)
+				origNodeIdxStack = append(origNodeIdxStack, origIdx)
+				visited[origIdx] = true
+				tree.nodes[convIdx].Right = uint32(convNewIdx)
+				continue
+			}
+		}
+		if tree.nodes[convIdx].Flags&leftLeaf == 0 {
+			origIdx := int(t.LeftChildren[origNodeIdxStack[len(origNodeIdxStack)-1]])
+			if !visited[origIdx] {
+				node, err := createNode(origIdx)
+				if err != nil {
+					return tree, err
+				}
+				tree.nodes = append(tree.nodes, node)
+				convNewIdx := len(tree.nodes) - 1
+				convNodeIdxStack = append(convNodeIdxStack, convNewIdx)
+				origNodeIdxStack = append(origNodeIdxStack, origIdx)
+				visited[origIdx] = true
+				tree.nodes[convIdx].Left = uint32(convNewIdx)
+				continue
+			}
+		}
+		origNodeIdxStack = origNodeIdxStack[:len(origNodeIdxStack)-1]
+		convNodeIdxStack = convNodeIdxStack[:len(convNodeIdxStack)-1]
+	}
+	return tree, nil
+}
diff --git a/xgensemble_json_test.go b/xgensemble_json_test.go
new file mode 100644
index 0000000..1347ada
--- /dev/null
+++ b/xgensemble_json_test.go
@@ -0,0 +1,318 @@
+package leaves
+
+import (
+	"encoding/json"
+	"math"
+	"os"
+	"testing"
+)
+
+const xgjsonTestdataDir = "internal/xgjson/testdata/"
+
+// --- shared test helpers ---
+
+type xgjsonExpected struct {
+	Features    []float64 `json:"features"`
+	RawScore    float64   `json:"raw_score"`
+	Probability float64   `json:"probability"`
+}
+
+type xgjsonMultiExpected struct {
+	Features      []float64 `json:"features"`
+	RawScores     []float64 `json:"raw_scores"`
+	Probabilities []float64 `json:"probabilities"`
+}
+
+type xgjsonPoissonExpected struct {
+	Features   []float64 `json:"features"`
+	RawScore   float64   `json:"raw_score"`
+	Prediction float64   `json:"prediction"`
+}
+
+func loadXGJSONExpected(t *testing.T, path string) []xgjsonExpected {
+	t.Helper()
+	data, err := os.ReadFile(path)
+	if err != nil {
+		t.Fatalf("read expected file: %v", err)
+	}
+	var out []xgjsonExpected
+	if err := json.Unmarshal(data, &out); err != nil {
+		t.Fatalf("unmarshal expected: %v", err)
+	}
+	return out
+}
+
+func loadXGJSONMultiExpected(t *testing.T, path string) []xgjsonMultiExpected {
+	t.Helper()
+	data, err := os.ReadFile(path)
+	if err != nil {
+		t.Fatalf("read expected file: %v", err)
+	}
+	var out []xgjsonMultiExpected
+	if err := json.Unmarshal(data, &out); err != nil {
+		t.Fatalf("unmarshal expected: %v", err)
+	}
+	return out
+}
+
+func loadXGJSONPoissonExpected(t *testing.T, path string) []xgjsonPoissonExpected {
+	t.Helper()
+	data, err := os.ReadFile(path)
+	if err != nil {
+		t.Fatalf("read expected file: %v", err)
+	}
+	var out []xgjsonPoissonExpected
+	if err := json.Unmarshal(data, &out); err != nil {
+		t.Fatalf("unmarshal expected: %v", err)
+	}
+	return out
+}
+
+func assertPredClose(t *testing.T, label string, got, want, tol float64) {
+	t.Helper()
+	if math.Abs(got-want) > tol {
+		t.Errorf("%s: got %.8f, want %.8f (diff %.2e > tol %.2e)", label, got, want, math.Abs(got-want), tol)
+	}
+}
+
+// --- binary:logistic tests ---
+
+func TestXGEnsembleJSON_BinaryLogistic(t *testing.T) {
+	ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_binary_logistic.json", false)
+	if err != nil {
+		t.Fatalf("XGEnsembleFromJSONFile: %v", err)
+	}
+
+	if ensemble.NFeatures() != 3 {
+		t.Errorf("NFeatures: got %d, want 3", ensemble.NFeatures())
+	}
+	if ensemble.NEstimators() != 10 {
+		t.Errorf("NEstimators: got %d, want 10", ensemble.NEstimators())
+	}
+	if ensemble.NRawOutputGroups() != 1 {
+		t.Errorf("NRawOutputGroups: got %d, want 1", ensemble.NRawOutputGroups())
+	}
+
+	expected := loadXGJSONExpected(t, xgjsonTestdataDir+"test_binary_logistic_expected.json")
+	for _, e := range expected {
+		got := ensemble.PredictSingle(e.Features, 0)
+		assertPredClose(t, "json raw prediction", got, e.RawScore, 1e-5)
+	}
+}
+
+func TestXGEnsembleUBJ_BinaryLogistic(t *testing.T) {
+	ensemble, err := XGEnsembleFromUBJFile(xgjsonTestdataDir+"test_binary_logistic.ubj", false)
+	if err != nil {
+		t.Fatalf("XGEnsembleFromUBJFile: %v", err)
+	}
+
+	if ensemble.NFeatures() != 3 {
+		t.Errorf("NFeatures: got %d, want 3", ensemble.NFeatures())
+	}
+	if ensemble.NEstimators() != 10 {
+		t.Errorf("NEstimators: got %d, want 10", ensemble.NEstimators())
+	}
+
+	expected := loadXGJSONExpected(t, xgjsonTestdataDir+"test_binary_logistic_expected.json")
+	for _, e := range expected {
+		got := ensemble.PredictSingle(e.Features, 0)
+		assertPredClose(t, "ubj raw prediction", got, e.RawScore, 1e-5)
+	}
+}
+
+func TestXGEnsembleJSONEqualsUBJ_BinaryLogistic(t *testing.T) {
+	jsonEnsemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_binary_logistic.json", false)
+	if err != nil {
+		t.Fatalf("JSON: %v", err)
+	}
+	ubjEnsemble, err := XGEnsembleFromUBJFile(xgjsonTestdataDir+"test_binary_logistic.ubj", false)
+	if err != nil {
+		t.Fatalf("UBJ: %v", err)
+	}
+
+	expected := loadXGJSONExpected(t, xgjsonTestdataDir+"test_binary_logistic_expected.json")
+	for _, e := range expected {
+		jsonPred := jsonEnsemble.PredictSingle(e.Features, 0)
+		ubjPred := ubjEnsemble.PredictSingle(e.Features, 0)
+		// UBJ stores leaf values as float32; JSON uses float64 — allow for rounding
+		assertPredClose(t, "json==ubj", jsonPred, ubjPred, 1e-5)
+	}
+}
+
+func TestXGEnsembleFileDispatch(t *testing.T) {
+	jsonEnsemble, err := XGEnsembleFromFile(xgjsonTestdataDir+"test_binary_logistic.json", false)
+	if err != nil {
+		t.Fatalf("XGEnsembleFromFile .json: %v", err)
+	}
+	ubjEnsemble, err := XGEnsembleFromFile(xgjsonTestdataDir+"test_binary_logistic.ubj", false)
+	if err != nil {
+		t.Fatalf("XGEnsembleFromFile .ubj: %v", err)
+	}
+
+	expected := loadXGJSONExpected(t, xgjsonTestdataDir+"test_binary_logistic_expected.json")
+	for _, e := range expected {
+		j := jsonEnsemble.PredictSingle(e.Features, 0)
+		u := ubjEnsemble.PredictSingle(e.Features, 0)
+		assertPredClose(t, "dispatch json==ubj", j, u, 1e-5)
+	}
+}
+
+func TestXGEnsembleJSON_BinaryLogisticTransformation(t *testing.T) {
+	ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_binary_logistic.json", true)
+	if err != nil {
+		t.Fatalf("XGEnsembleFromJSONFile with transform: %v", err)
+	}
+
+	if ensemble.NOutputGroups() != 1 {
+		t.Errorf("NOutputGroups: got %d, want 1", ensemble.NOutputGroups())
+	}
+
+	expected := loadXGJSONExpected(t, xgjsonTestdataDir+"test_binary_logistic_expected.json")
+	for _, e := range expected {
+		got := ensemble.PredictSingle(e.Features, 0)
+		assertPredClose(t, "sigmoid prediction", got, e.Probability, 1e-5)
+	}
+}
+
+// --- reg:squarederror tests ---
+
+func TestXGEnsembleJSON_Regression(t *testing.T) {
+	ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_regression.json", false)
+	if err != nil {
+		t.Fatalf("XGEnsembleFromJSONFile: %v", err)
+	}
+
+	if ensemble.NFeatures() != 4 {
+		t.Errorf("NFeatures: got %d, want 4", ensemble.NFeatures())
+	}
+	if ensemble.NEstimators() != 10 {
+		t.Errorf("NEstimators: got %d, want 10", ensemble.NEstimators())
+	}
+	if ensemble.NRawOutputGroups() != 1 {
+		t.Errorf("NRawOutputGroups: got %d, want 1", ensemble.NRawOutputGroups())
+	}
+
+	// reg:squarederror fixtures are decoded into a local struct with raw_score only
+	data, err := os.ReadFile(xgjsonTestdataDir + "test_regression_expected.json")
+	if err != nil {
+		t.Fatalf("read regression expected: %v", err)
+	}
+	var expected []struct {
+		Features []float64 `json:"features"`
+		RawScore float64   `json:"raw_score"`
+	}
+	if err := json.Unmarshal(data, &expected); err != nil {
+		t.Fatalf("unmarshal regression expected: %v", err)
+	}
+
+	for _, e := range expected {
+		got := ensemble.PredictSingle(e.Features, 0)
+		assertPredClose(t, "regression raw", got, e.RawScore, 1e-5)
+	}
+}
+
+// --- multi:softprob tests ---
+
+func TestXGEnsembleJSON_Multiclass(t *testing.T) {
+	expected := loadXGJSONMultiExpected(t, xgjsonTestdataDir+"test_multiclass_expected.json")
+
+	t.Run("raw scores", func(t *testing.T) {
+		ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_multiclass.json", false)
+		if err != nil {
+			t.Fatalf("load: %v", err)
+		}
+		if ensemble.NFeatures() != 4 {
+			t.Errorf("NFeatures: got %d, want 4", ensemble.NFeatures())
+		}
+		if ensemble.NEstimators() != 10 {
+			t.Errorf("NEstimators: got %d, want 10", ensemble.NEstimators())
+		}
+		if ensemble.NRawOutputGroups() != 3 {
+			t.Errorf("NRawOutputGroups: got %d, want 3", ensemble.NRawOutputGroups())
+		}
+
+		preds := make([]float64, ensemble.NOutputGroups())
+		for _, e := range expected {
+			if err := ensemble.Predict(e.Features, 0, preds); err != nil {
+				t.Fatalf("Predict: %v", err)
+			}
+			for cls := 0; cls < 3; cls++ {
+				assertPredClose(t, "multiclass raw", preds[cls], e.RawScores[cls], 1e-5)
+			}
+		}
+	})
+
+	t.Run("softmax probabilities", func(t *testing.T) {
+		ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_multiclass.json", true)
+		if err != nil {
+			t.Fatalf("load with transform: %v", err)
+		}
+		if ensemble.NOutputGroups() != 3 {
+			t.Errorf("NOutputGroups: got %d, want 3", ensemble.NOutputGroups())
+		}
+
+		preds := make([]float64, ensemble.NOutputGroups())
+		for _, e := range expected {
+			if err := ensemble.Predict(e.Features, 0, preds); err != nil {
+				t.Fatalf("Predict: %v", err)
+			}
+			for cls := 0; cls < 3; cls++ {
+				assertPredClose(t, "multiclass prob", preds[cls], e.Probabilities[cls], 1e-5)
+			}
+		}
+	})
+
+	t.Run("json equals ubj", func(t *testing.T) {
+		jsonEnsemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_multiclass.json", false)
+		if err != nil {
+			t.Fatalf("JSON: %v", err)
+		}
+		ubjEnsemble, err := XGEnsembleFromUBJFile(xgjsonTestdataDir+"test_multiclass.ubj", false)
+		if err != nil {
+			t.Fatalf("UBJ: %v", err)
+		}
+
+		jsonPreds := make([]float64, jsonEnsemble.NOutputGroups())
+		ubjPreds := make([]float64, ubjEnsemble.NOutputGroups())
+		for _, e := range expected {
+			if err := jsonEnsemble.Predict(e.Features, 0, jsonPreds); err != nil {
+				t.Fatalf("JSON Predict: %v", err)
+			}
+			if err := ubjEnsemble.Predict(e.Features, 0, ubjPreds); err != nil {
+				t.Fatalf("UBJ Predict: %v", err)
+			}
+			for cls := 0; cls < 3; cls++ {
+				assertPredClose(t, "multiclass json==ubj", jsonPreds[cls], ubjPreds[cls], 1e-5)
+			}
+		}
+	})
+}
+
+// --- count:poisson tests ---
+
+func TestXGEnsembleJSON_Poisson(t *testing.T) {
+	expected := loadXGJSONPoissonExpected(t, xgjsonTestdataDir+"test_poisson_expected.json")
+
+	t.Run("raw margins", func(t *testing.T) {
+		ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_poisson.json", false)
+		if err != nil {
+			t.Fatalf("load: %v", err)
+		}
+		for _, e := range expected {
+			got := ensemble.PredictSingle(e.Features, 0)
+			assertPredClose(t, "poisson raw", got, e.RawScore, 1e-5)
+		}
+	})
+
+	t.Run("exp predictions", func(t *testing.T) {
+		ensemble, err := XGEnsembleFromJSONFile(xgjsonTestdataDir+"test_poisson.json", true)
+		if err != nil {
+			t.Fatalf("load with transform: %v", err)
+		}
+		for _, e := range expected {
+			got := ensemble.PredictSingle(e.Features, 0)
+			assertPredClose(t, "poisson exp", got, e.Prediction, 1e-5)
+		}
+	})
+}
+