11/* *
2- * Copyright 2015-2023 , XGBoost Contributors
2+ * Copyright 2015-2024 , XGBoost Contributors
33 * \file tree_model.cc
44 * \brief model structure for tree
55 */
88#include < xgboost/json.h>
99#include < xgboost/tree_model.h>
1010
11+ #include < array> // for array
1112#include < cmath>
1213#include < iomanip>
1314#include < limits>
1415#include < sstream>
1516#include < type_traits>
1617
1718#include " ../common/categorical.h"
18- #include " ../common/common.h" // for EscapeU8
19+ #include " ../common/common.h" // for EscapeU8
1920#include " ../predictor/predict_fn.h"
2021#include " io_utils.h" // for GetElem
2122#include " param.h"
@@ -31,26 +32,50 @@ namespace tree {
3132DMLC_REGISTER_PARAMETER (TrainParam);
3233}
3334
35+ namespace {
36+ template <typename Float>
37+ std::enable_if_t <std::is_floating_point_v<Float>, std::string> ToStr (Float value) {
38+ int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float >::max_digits10;
39+ static_assert (std::is_floating_point<Float>::value,
40+ " Use std::to_string instead for non-floating point values." );
41+ std::stringstream ss;
42+ ss << std::setprecision (kFloatMaxPrecision ) << value;
43+ return ss.str ();
44+ }
45+
46+ template <typename Float>
47+ std::string ToStr (linalg::VectorView<Float> value, bst_target_t limit) {
48+ int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float >::max_digits10;
49+ static_assert (std::is_floating_point<Float>::value,
50+ " Use std::to_string instead for non-floating point values." );
51+ std::stringstream ss;
52+ ss << std::setprecision (kFloatMaxPrecision );
53+ if (value.Size () == 1 ) {
54+ ss << value (0 );
55+ return ss.str ();
56+ }
57+ CHECK_GE (limit, 2 );
58+ auto n = std::min (static_cast <bst_target_t >(value.Size () - 1 ), limit - 1 );
59+ ss << " [" ;
60+ for (std::size_t i = 0 ; i < n; ++i) {
61+ ss << value (i) << " , " ;
62+ }
63+ if (value.Size () > limit) {
64+ ss << " ..., " ;
65+ }
66+ ss << value (value.Size () - 1 ) << " ]" ;
67+ return ss.str ();
68+ }
69+ } // namespace
3470/* !
3571 * \brief Base class for dump model implementation, modeling closely after code generator.
3672 */
3773class TreeGenerator {
3874 protected:
39- static int32_t constexpr kFloatMaxPrecision =
40- std::numeric_limits<bst_float>::max_digits10;
4175 FeatureMap const & fmap_;
4276 std::stringstream ss_;
4377 bool const with_stats_;
4478
45- template <typename Float>
46- static std::string ToStr (Float value) {
47- static_assert (std::is_floating_point<Float>::value,
48- " Use std::to_string instead for non-floating point values." );
49- std::stringstream ss;
50- ss << std::setprecision (kFloatMaxPrecision ) << value;
51- return ss.str ();
52- }
53-
5479 static std::string Tabs (uint32_t n) {
5580 std::string res;
5681 for (uint32_t i = 0 ; i < n; ++i) {
@@ -258,10 +283,10 @@ class TextGenerator : public TreeGenerator {
258283 kLeafTemplate ,
259284 {{" {tabs}" , SuperT::Tabs (depth)},
260285 {" {nid}" , std::to_string (nid)},
261- {" {leaf}" , SuperT:: ToStr (tree[nid].LeafValue ())},
286+ {" {leaf}" , ToStr (tree[nid].LeafValue ())},
262287 {" {stats}" , with_stats_ ?
263288 SuperT::Match (kStatTemplate ,
264- {{" {cover}" , SuperT:: ToStr (tree.Stat (nid).sum_hess )}}) : " " }});
289+ {{" {cover}" , ToStr (tree.Stat (nid).sum_hess )}}) : " " }});
265290 return result;
266291 }
267292
@@ -311,14 +336,14 @@ class TextGenerator : public TreeGenerator {
311336 static std::string const kQuantitiveTemplate =
312337 " {tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}" ;
313338 auto cond = tree[nid].SplitCond ();
314- return SplitNodeImpl (tree, nid, kQuantitiveTemplate , SuperT:: ToStr (cond), depth);
339+ return SplitNodeImpl (tree, nid, kQuantitiveTemplate , ToStr (cond), depth);
315340 }
316341
317342 std::string PlainNode (RegTree const & tree, int32_t nid, uint32_t depth) const override {
318343 auto cond = tree[nid].SplitCond ();
319344 static std::string const kNodeTemplate =
320345 " {tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}" ;
321- return SplitNodeImpl (tree, nid, kNodeTemplate , SuperT:: ToStr (cond), depth);
346+ return SplitNodeImpl (tree, nid, kNodeTemplate , ToStr (cond), depth);
322347 }
323348
324349 std::string Categorical (RegTree const &tree, int32_t nid,
@@ -336,8 +361,8 @@ class TextGenerator : public TreeGenerator {
336361 static std::string const kStatTemplate = " ,gain={loss_chg},cover={sum_hess}" ;
337362 std::string const result = SuperT::Match (
338363 kStatTemplate ,
339- {{" {loss_chg}" , SuperT:: ToStr (tree.Stat (nid).loss_chg )},
340- {" {sum_hess}" , SuperT:: ToStr (tree.Stat (nid).sum_hess )}});
364+ {{" {loss_chg}" , ToStr (tree.Stat (nid).loss_chg )},
365+ {" {sum_hess}" , ToStr (tree.Stat (nid).sum_hess )}});
341366 return result;
342367 }
343368
@@ -393,11 +418,11 @@ class JsonGenerator : public TreeGenerator {
393418 std::string result = SuperT::Match (
394419 kLeafTemplate ,
395420 {{" {nid}" , std::to_string (nid)},
396- {" {leaf}" , SuperT:: ToStr (tree[nid].LeafValue ())},
421+ {" {leaf}" , ToStr (tree[nid].LeafValue ())},
397422 {" {stat}" , with_stats_ ? SuperT::Match (
398423 kStatTemplate ,
399424 {{" {sum_hess}" ,
400- SuperT:: ToStr (tree.Stat (nid).sum_hess )}}) : " " }});
425+ ToStr (tree.Stat (nid).sum_hess )}}) : " " }});
401426 return result;
402427 }
403428
@@ -468,7 +493,7 @@ class JsonGenerator : public TreeGenerator {
468493 R"I( "split_condition": {cond}, "yes": {left}, "no": {right}, )I"
469494 R"I( "missing": {missing})I" ;
470495 bst_float cond = tree[nid].SplitCond ();
471- return SplitNodeImpl (tree, nid, kQuantitiveTemplate , SuperT:: ToStr (cond), depth);
496+ return SplitNodeImpl (tree, nid, kQuantitiveTemplate , ToStr (cond), depth);
472497 }
473498
474499 std::string PlainNode (RegTree const & tree, int32_t nid, uint32_t depth) const override {
@@ -477,16 +502,16 @@ class JsonGenerator : public TreeGenerator {
477502 R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
478503 R"I( "split_condition": {cond}, "yes": {left}, "no": {right}, )I"
479504 R"I( "missing": {missing})I" ;
480- return SplitNodeImpl (tree, nid, kNodeTemplate , SuperT:: ToStr (cond), depth);
505+ return SplitNodeImpl (tree, nid, kNodeTemplate , ToStr (cond), depth);
481506 }
482507
483508 std::string NodeStat (RegTree const & tree, int32_t nid) const override {
484509 static std::string kStatTemplate =
485510 R"S( , "gain": {loss_chg}, "cover": {sum_hess})S" ;
486511 auto result = SuperT::Match (
487512 kStatTemplate ,
488- {{" {loss_chg}" , SuperT:: ToStr (tree.Stat (nid).loss_chg )},
489- {" {sum_hess}" , SuperT:: ToStr (tree.Stat (nid).sum_hess )}});
513+ {{" {loss_chg}" , ToStr (tree.Stat (nid).loss_chg )},
514+ {" {sum_hess}" , ToStr (tree.Stat (nid).sum_hess )}});
490515 return result;
491516 }
492517
@@ -622,11 +647,11 @@ class GraphvizGenerator : public TreeGenerator {
622647
623648 protected:
624649 template <bool is_categorical>
625- std::string BuildEdge (RegTree const &tree, bst_node_t nid , int32_t child, bool left) const {
650+ std::string BuildEdge (RegTree const &tree, bst_node_t nidx , int32_t child, bool left) const {
626651 static std::string const kEdgeTemplate =
627652 " {nid} -> {child} [label=\" {branch}\" color=\" {color}\" ]\n " ;
628653 // Is this the default child for missing value?
629- bool is_missing = tree[nid] .DefaultChild () == child;
654+ bool is_missing = tree.DefaultChild (nidx ) == child;
630655 std::string branch;
631656 if (is_categorical) {
632657 branch = std::string{left ? " no" : " yes" } + std::string{is_missing ? " , missing" : " " };
@@ -635,7 +660,7 @@ class GraphvizGenerator : public TreeGenerator {
635660 }
636661 std::string buffer =
637662 SuperT::Match (kEdgeTemplate ,
638- {{" {nid}" , std::to_string (nid )},
663+ {{" {nid}" , std::to_string (nidx )},
639664 {" {child}" , std::to_string (child)},
640665 {" {color}" , is_missing ? param_.yes_color : param_.no_color },
641666 {" {branch}" , branch}});
@@ -644,68 +669,77 @@ class GraphvizGenerator : public TreeGenerator {
644669
645670 // Only indicator is different, so we combine all different node types into this
646671 // function.
647- std::string PlainNode (RegTree const & tree, int32_t nid , uint32_t ) const override {
648- auto split_index = tree[nid] .SplitIndex ();
649- auto cond = tree[nid] .SplitCond ();
672+ std::string PlainNode (RegTree const & tree, bst_node_t nidx , uint32_t ) const override {
673+ auto split_index = tree.SplitIndex (nidx );
674+ auto cond = tree.SplitCond (nidx );
650675 static std::string const kNodeTemplate = " {nid} [ label=\" {fname}{<}{cond}\" {params}]\n " ;
651676
652677 bool has_less =
653678 (split_index >= fmap_.Size ()) || fmap_.TypeOf (split_index) != FeatureMap::kIndicator ;
654679 std::string result =
655- SuperT::Match (kNodeTemplate , {{" {nid}" , std::to_string (nid )},
680+ SuperT::Match (kNodeTemplate , {{" {nid}" , std::to_string (nidx )},
656681 {" {fname}" , GetFeatureName (fmap_, split_index)},
657682 {" {<}" , has_less ? " <" : " " },
658- {" {cond}" , has_less ? SuperT:: ToStr (cond) : " " },
683+ {" {cond}" , has_less ? ToStr (cond) : " " },
659684 {" {params}" , param_.condition_node_params }});
660685
661- result += BuildEdge<false >(tree, nid , tree[nid] .LeftChild (), true );
662- result += BuildEdge<false >(tree, nid , tree[nid] .RightChild (), false );
686+ result += BuildEdge<false >(tree, nidx , tree.LeftChild (nidx ), true );
687+ result += BuildEdge<false >(tree, nidx , tree.RightChild (nidx ), false );
663688
664689 return result;
665690 };
666691
667- std::string Categorical (RegTree const & tree, int32_t nid , uint32_t ) const override {
692+ std::string Categorical (RegTree const & tree, bst_node_t nidx , uint32_t ) const override {
668693 static std::string const kLabelTemplate =
669694 " {nid} [ label=\" {fname}:{cond}\" {params}]\n " ;
670- auto cats = GetSplitCategories (tree, nid );
695+ auto cats = GetSplitCategories (tree, nidx );
671696 auto cats_str = PrintCatsAsSet (cats);
672- auto split_index = tree[nid] .SplitIndex ();
697+ auto split_index = tree.SplitIndex (nidx );
673698
674699 std::string result =
675- SuperT::Match (kLabelTemplate , {{" {nid}" , std::to_string (nid )},
700+ SuperT::Match (kLabelTemplate , {{" {nid}" , std::to_string (nidx )},
676701 {" {fname}" , GetFeatureName (fmap_, split_index)},
677702 {" {cond}" , cats_str},
678703 {" {params}" , param_.condition_node_params }});
679704
680- result += BuildEdge<true >(tree, nid , tree[nid] .LeftChild (), true );
681- result += BuildEdge<true >(tree, nid , tree[nid] .RightChild (), false );
705+ result += BuildEdge<true >(tree, nidx , tree.LeftChild (nidx ), true );
706+ result += BuildEdge<true >(tree, nidx , tree.RightChild (nidx ), false );
682707
683708 return result;
684709 }
685710
686- std::string LeafNode (RegTree const & tree, int32_t nid, uint32_t ) const override {
687- static std::string const kLeafTemplate =
688- " {nid} [ label=\" leaf={leaf-value}\" {params}]\n " ;
689- auto result = SuperT::Match (kLeafTemplate , {
690- {" {nid}" , std::to_string (nid)},
691- {" {leaf-value}" , ToStr (tree[nid].LeafValue ())},
692- {" {params}" , param_.leaf_node_params }});
693- return result;
694- };
711+ std::string LeafNode (RegTree const & tree, bst_node_t nidx, uint32_t ) const override {
712+ static std::string const kLeafTemplate = " {nid} [ label=\" leaf={leaf-value}\" {params}]\n " ;
713+ // hardcoded limit to avoid dumping long arrays into dot graph.
714+ bst_target_t constexpr kLimit {3 };
715+ if (tree.IsMultiTarget ()) {
716+ auto value = tree.GetMultiTargetTree ()->LeafValue (nidx);
717+ auto result = SuperT::Match (kLeafTemplate , {{" {nid}" , std::to_string (nidx)},
718+ {" {leaf-value}" , ToStr (value, kLimit )},
719+ {" {params}" , param_.leaf_node_params }});
720+ return result;
721+ } else {
722+ auto value = tree[nidx].LeafValue ();
723+ auto result = SuperT::Match (kLeafTemplate , {{" {nid}" , std::to_string (nidx)},
724+ {" {leaf-value}" , ToStr (value)},
725+ {" {params}" , param_.leaf_node_params }});
726+ return result;
727+ }
728+ }
695729
696- std::string BuildTree (RegTree const & tree, int32_t nid , uint32_t depth) override {
697- if (tree[nid] .IsLeaf ()) {
698- return this ->LeafNode (tree, nid , depth);
730+ std::string BuildTree (RegTree const & tree, bst_node_t nidx , uint32_t depth) override {
731+ if (tree.IsLeaf (nidx )) {
732+ return this ->LeafNode (tree, nidx , depth);
699733 }
700734 static std::string const kNodeTemplate = " {parent}\n {left}\n {right}" ;
701- auto node = tree.GetSplitTypes ()[nid ] == FeatureType::kCategorical
702- ? this ->Categorical (tree, nid , depth)
703- : this ->PlainNode (tree, nid , depth);
735+ auto node = tree.GetSplitTypes ()[nidx ] == FeatureType::kCategorical
736+ ? this ->Categorical (tree, nidx , depth)
737+ : this ->PlainNode (tree, nidx , depth);
704738 auto result = SuperT::Match (
705739 kNodeTemplate ,
706740 {{" {parent}" , node},
707- {" {left}" , this ->BuildTree (tree, tree[nid] .LeftChild (), depth+1 )},
708- {" {right}" , this ->BuildTree (tree, tree[nid] .RightChild (), depth+1 )}});
741+ {" {left}" , this ->BuildTree (tree, tree.LeftChild (nidx ), depth+1 )},
742+ {" {right}" , this ->BuildTree (tree, tree.RightChild (nidx ), depth+1 )}});
709743 return result;
710744 }
711745
@@ -733,7 +767,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
733767constexpr bst_node_t RegTree::kRoot ;
734768
735769std::string RegTree::DumpModel (const FeatureMap& fmap, bool with_stats, std::string format) const {
736- CHECK (!IsMultiTarget ());
770+ if (this ->IsMultiTarget () && format != " dot" ) {
771+ LOG (FATAL) << format << " tree dump " << MTNotImplemented ();
772+ }
737773 std::unique_ptr<TreeGenerator> builder{TreeGenerator::Create (format, fmap, with_stats)};
738774 builder->BuildTree (*this );
739775
0 commit comments