Skip to content

Commit 2c13f90

Browse files
authored
Support graphviz plot for multi-target tree. (dmlc#10093)
1 parent e14c3b9 commit 2c13f90

File tree

3 files changed

+133
-63
lines changed

3 files changed

+133
-63
lines changed

include/xgboost/tree_model.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2014-2023 by Contributors
2+
* Copyright 2014-2024, XGBoost Contributors
33
* \file tree_model.h
44
* \brief model structure for tree
55
* \author Tianqi Chen
@@ -688,6 +688,9 @@ class RegTree : public Model {
688688
}
689689
return (*this)[nidx].DefaultLeft();
690690
}
691+
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
692+
return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
693+
}
691694
[[nodiscard]] bool IsRoot(bst_node_t nidx) const {
692695
if (IsMultiTarget()) {
693696
return nidx == kRoot;

src/tree/tree_model.cc

Lines changed: 95 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2015-2023, XGBoost Contributors
2+
* Copyright 2015-2024, XGBoost Contributors
33
* \file tree_model.cc
44
* \brief model structure for tree
55
*/
@@ -8,14 +8,15 @@
88
#include <xgboost/json.h>
99
#include <xgboost/tree_model.h>
1010

11+
#include <array> // for array
1112
#include <cmath>
1213
#include <iomanip>
1314
#include <limits>
1415
#include <sstream>
1516
#include <type_traits>
1617

1718
#include "../common/categorical.h"
18-
#include "../common/common.h" // for EscapeU8
19+
#include "../common/common.h" // for EscapeU8
1920
#include "../predictor/predict_fn.h"
2021
#include "io_utils.h" // for GetElem
2122
#include "param.h"
@@ -31,26 +32,50 @@ namespace tree {
3132
DMLC_REGISTER_PARAMETER(TrainParam);
3233
}
3334

35+
namespace {
36+
template <typename Float>
37+
std::enable_if_t<std::is_floating_point_v<Float>, std::string> ToStr(Float value) {
38+
int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
39+
static_assert(std::is_floating_point<Float>::value,
40+
"Use std::to_string instead for non-floating point values.");
41+
std::stringstream ss;
42+
ss << std::setprecision(kFloatMaxPrecision) << value;
43+
return ss.str();
44+
}
45+
46+
template <typename Float>
47+
std::string ToStr(linalg::VectorView<Float> value, bst_target_t limit) {
48+
int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
49+
static_assert(std::is_floating_point<Float>::value,
50+
"Use std::to_string instead for non-floating point values.");
51+
std::stringstream ss;
52+
ss << std::setprecision(kFloatMaxPrecision);
53+
if (value.Size() == 1) {
54+
ss << value(0);
55+
return ss.str();
56+
}
57+
CHECK_GE(limit, 2);
58+
auto n = std::min(static_cast<bst_target_t>(value.Size() - 1), limit - 1);
59+
ss << "[";
60+
for (std::size_t i = 0; i < n; ++i) {
61+
ss << value(i) << ", ";
62+
}
63+
if (value.Size() > limit) {
64+
ss << "..., ";
65+
}
66+
ss << value(value.Size() - 1) << "]";
67+
return ss.str();
68+
}
69+
} // namespace
3470
/*!
3571
* \brief Base class for dump model implementation, modeling closely after code generator.
3672
*/
3773
class TreeGenerator {
3874
protected:
39-
static int32_t constexpr kFloatMaxPrecision =
40-
std::numeric_limits<bst_float>::max_digits10;
4175
FeatureMap const& fmap_;
4276
std::stringstream ss_;
4377
bool const with_stats_;
4478

45-
template <typename Float>
46-
static std::string ToStr(Float value) {
47-
static_assert(std::is_floating_point<Float>::value,
48-
"Use std::to_string instead for non-floating point values.");
49-
std::stringstream ss;
50-
ss << std::setprecision(kFloatMaxPrecision) << value;
51-
return ss.str();
52-
}
53-
5479
static std::string Tabs(uint32_t n) {
5580
std::string res;
5681
for (uint32_t i = 0; i < n; ++i) {
@@ -258,10 +283,10 @@ class TextGenerator : public TreeGenerator {
258283
kLeafTemplate,
259284
{{"{tabs}", SuperT::Tabs(depth)},
260285
{"{nid}", std::to_string(nid)},
261-
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
286+
{"{leaf}", ToStr(tree[nid].LeafValue())},
262287
{"{stats}", with_stats_ ?
263288
SuperT::Match(kStatTemplate,
264-
{{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
289+
{{"{cover}", ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
265290
return result;
266291
}
267292

@@ -311,14 +336,14 @@ class TextGenerator : public TreeGenerator {
311336
static std::string const kQuantitiveTemplate =
312337
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
313338
auto cond = tree[nid].SplitCond();
314-
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
339+
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
315340
}
316341

317342
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
318343
auto cond = tree[nid].SplitCond();
319344
static std::string const kNodeTemplate =
320345
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
321-
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
346+
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
322347
}
323348

324349
std::string Categorical(RegTree const &tree, int32_t nid,
@@ -336,8 +361,8 @@ class TextGenerator : public TreeGenerator {
336361
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
337362
std::string const result = SuperT::Match(
338363
kStatTemplate,
339-
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
340-
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
364+
{{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
365+
{"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
341366
return result;
342367
}
343368

@@ -393,11 +418,11 @@ class JsonGenerator : public TreeGenerator {
393418
std::string result = SuperT::Match(
394419
kLeafTemplate,
395420
{{"{nid}", std::to_string(nid)},
396-
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
421+
{"{leaf}", ToStr(tree[nid].LeafValue())},
397422
{"{stat}", with_stats_ ? SuperT::Match(
398423
kStatTemplate,
399424
{{"{sum_hess}",
400-
SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
425+
ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
401426
return result;
402427
}
403428

@@ -468,7 +493,7 @@ class JsonGenerator : public TreeGenerator {
468493
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
469494
R"I("missing": {missing})I";
470495
bst_float cond = tree[nid].SplitCond();
471-
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
496+
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
472497
}
473498

474499
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
@@ -477,16 +502,16 @@ class JsonGenerator : public TreeGenerator {
477502
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
478503
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
479504
R"I("missing": {missing})I";
480-
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
505+
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
481506
}
482507

483508
std::string NodeStat(RegTree const& tree, int32_t nid) const override {
484509
static std::string kStatTemplate =
485510
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
486511
auto result = SuperT::Match(
487512
kStatTemplate,
488-
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
489-
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
513+
{{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
514+
{"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
490515
return result;
491516
}
492517

@@ -622,11 +647,11 @@ class GraphvizGenerator : public TreeGenerator {
622647

623648
protected:
624649
template <bool is_categorical>
625-
std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const {
650+
std::string BuildEdge(RegTree const &tree, bst_node_t nidx, int32_t child, bool left) const {
626651
static std::string const kEdgeTemplate =
627652
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
628653
// Is this the default child for missing value?
629-
bool is_missing = tree[nid].DefaultChild() == child;
654+
bool is_missing = tree.DefaultChild(nidx) == child;
630655
std::string branch;
631656
if (is_categorical) {
632657
branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""};
@@ -635,7 +660,7 @@ class GraphvizGenerator : public TreeGenerator {
635660
}
636661
std::string buffer =
637662
SuperT::Match(kEdgeTemplate,
638-
{{"{nid}", std::to_string(nid)},
663+
{{"{nid}", std::to_string(nidx)},
639664
{"{child}", std::to_string(child)},
640665
{"{color}", is_missing ? param_.yes_color : param_.no_color},
641666
{"{branch}", branch}});
@@ -644,68 +669,77 @@ class GraphvizGenerator : public TreeGenerator {
644669

645670
// Only indicator is different, so we combine all different node types into this
646671
// function.
647-
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
648-
auto split_index = tree[nid].SplitIndex();
649-
auto cond = tree[nid].SplitCond();
672+
std::string PlainNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
673+
auto split_index = tree.SplitIndex(nidx);
674+
auto cond = tree.SplitCond(nidx);
650675
static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
651676

652677
bool has_less =
653678
(split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator;
654679
std::string result =
655-
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nid)},
680+
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)},
656681
{"{fname}", GetFeatureName(fmap_, split_index)},
657682
{"{<}", has_less ? "<" : ""},
658-
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
683+
{"{cond}", has_less ? ToStr(cond) : ""},
659684
{"{params}", param_.condition_node_params}});
660685

661-
result += BuildEdge<false>(tree, nid, tree[nid].LeftChild(), true);
662-
result += BuildEdge<false>(tree, nid, tree[nid].RightChild(), false);
686+
result += BuildEdge<false>(tree, nidx, tree.LeftChild(nidx), true);
687+
result += BuildEdge<false>(tree, nidx, tree.RightChild(nidx), false);
663688

664689
return result;
665690
};
666691

667-
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override {
692+
std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
668693
static std::string const kLabelTemplate =
669694
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
670-
auto cats = GetSplitCategories(tree, nid);
695+
auto cats = GetSplitCategories(tree, nidx);
671696
auto cats_str = PrintCatsAsSet(cats);
672-
auto split_index = tree[nid].SplitIndex();
697+
auto split_index = tree.SplitIndex(nidx);
673698

674699
std::string result =
675-
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nid)},
700+
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nidx)},
676701
{"{fname}", GetFeatureName(fmap_, split_index)},
677702
{"{cond}", cats_str},
678703
{"{params}", param_.condition_node_params}});
679704

680-
result += BuildEdge<true>(tree, nid, tree[nid].LeftChild(), true);
681-
result += BuildEdge<true>(tree, nid, tree[nid].RightChild(), false);
705+
result += BuildEdge<true>(tree, nidx, tree.LeftChild(nidx), true);
706+
result += BuildEdge<true>(tree, nidx, tree.RightChild(nidx), false);
682707

683708
return result;
684709
}
685710

686-
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
687-
static std::string const kLeafTemplate =
688-
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
689-
auto result = SuperT::Match(kLeafTemplate, {
690-
{"{nid}", std::to_string(nid)},
691-
{"{leaf-value}", ToStr(tree[nid].LeafValue())},
692-
{"{params}", param_.leaf_node_params}});
693-
return result;
694-
};
711+
std::string LeafNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
712+
static std::string const kLeafTemplate = " {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
713+
// hardcoded limit to avoid dumping long arrays into dot graph.
714+
bst_target_t constexpr kLimit{3};
715+
if (tree.IsMultiTarget()) {
716+
auto value = tree.GetMultiTargetTree()->LeafValue(nidx);
717+
auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)},
718+
{"{leaf-value}", ToStr(value, kLimit)},
719+
{"{params}", param_.leaf_node_params}});
720+
return result;
721+
} else {
722+
auto value = tree[nidx].LeafValue();
723+
auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)},
724+
{"{leaf-value}", ToStr(value)},
725+
{"{params}", param_.leaf_node_params}});
726+
return result;
727+
}
728+
}
695729

696-
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
697-
if (tree[nid].IsLeaf()) {
698-
return this->LeafNode(tree, nid, depth);
730+
std::string BuildTree(RegTree const& tree, bst_node_t nidx, uint32_t depth) override {
731+
if (tree.IsLeaf(nidx)) {
732+
return this->LeafNode(tree, nidx, depth);
699733
}
700734
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
701-
auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical
702-
? this->Categorical(tree, nid, depth)
703-
: this->PlainNode(tree, nid, depth);
735+
auto node = tree.GetSplitTypes()[nidx] == FeatureType::kCategorical
736+
? this->Categorical(tree, nidx, depth)
737+
: this->PlainNode(tree, nidx, depth);
704738
auto result = SuperT::Match(
705739
kNodeTemplate,
706740
{{"{parent}", node},
707-
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
708-
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
741+
{"{left}", this->BuildTree(tree, tree.LeftChild(nidx), depth+1)},
742+
{"{right}", this->BuildTree(tree, tree.RightChild(nidx), depth+1)}});
709743
return result;
710744
}
711745

@@ -733,7 +767,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
733767
constexpr bst_node_t RegTree::kRoot;
734768

735769
std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const {
736-
CHECK(!IsMultiTarget());
770+
if (this->IsMultiTarget() && format != "dot") {
771+
LOG(FATAL) << format << " tree dump " << MTNotImplemented();
772+
}
737773
std::unique_ptr<TreeGenerator> builder{TreeGenerator::Create(format, fmap, with_stats)};
738774
builder->BuildTree(*this);
739775

tests/cpp/tree/test_multi_target_tree_model.cc

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,29 @@
11
/**
2-
* Copyright 2023 by XGBoost Contributors
2+
* Copyright 2023-2024, XGBoost Contributors
33
*/
44
#include <gtest/gtest.h>
55
#include <xgboost/context.h> // for Context
66
#include <xgboost/multi_target_tree_model.h>
77
#include <xgboost/tree_model.h> // for RegTree
88

99
namespace xgboost {
10-
TEST(MultiTargetTree, JsonIO) {
10+
namespace {
11+
auto MakeTreeForTest() {
1112
bst_target_t n_targets{3};
1213
bst_feature_t n_features{4};
1314
RegTree tree{n_targets, n_features};
14-
ASSERT_TRUE(tree.IsMultiTarget());
15+
CHECK(tree.IsMultiTarget());
1516
linalg::Vector<float> base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()};
1617
linalg::Vector<float> left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()};
1718
linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()};
1819
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
1920
left_weight.HostView(), right_weight.HostView());
21+
return tree;
22+
}
23+
} // namespace
24+
25+
TEST(MultiTargetTree, JsonIO) {
26+
auto tree = MakeTreeForTest();
2027
ASSERT_EQ(tree.NumNodes(), 3);
2128
ASSERT_EQ(tree.NumTargets(), 3);
2229
ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
@@ -44,4 +51,28 @@ TEST(MultiTargetTree, JsonIO) {
4451
loaded.SaveModel(&jtree1);
4552
check_jtree(jtree1, tree);
4653
}
54+
55+
TEST(MultiTargetTree, DumpDot) {
56+
auto tree = MakeTreeForTest();
57+
auto n_features = tree.NumFeatures();
58+
FeatureMap fmap;
59+
for (bst_feature_t f = 0; f < n_features; ++f) {
60+
auto name = "feat_" + std::to_string(f);
61+
fmap.PushBack(f, name.c_str(), "q");
62+
}
63+
auto str = tree.DumpModel(fmap, true, "dot");
64+
ASSERT_NE(str.find("leaf=[2, 3, 4]"), std::string::npos);
65+
ASSERT_NE(str.find("leaf=[3, 4, 5]"), std::string::npos);
66+
67+
{
68+
bst_target_t n_targets{4};
69+
bst_feature_t n_features{4};
70+
RegTree tree{n_targets, n_features};
71+
linalg::Vector<float> weight{{1.0f, 2.0f, 3.0f, 4.0f}, {4ul}, DeviceOrd::CPU()};
72+
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, weight.HostView(),
73+
weight.HostView(), weight.HostView());
74+
auto str = tree.DumpModel(fmap, true, "dot");
75+
ASSERT_NE(str.find("leaf=[1, 2, ..., 4]"), std::string::npos);
76+
}
77+
}
4778
} // namespace xgboost

0 commit comments

Comments
 (0)