Skip to content

Commit 5594d13

Browse files
common: fix missing exports in llama-common (ggml-org#22340)
* common: refactor common/debug to move abort_on_nan into base_callback_data Passing bool abort_on_nan as template parameter for common_debug_cb_eval is unnecessary and creates an issue with LTO. It should just be a member of the base_callback_data instead. * cont : cleanup * common : use pimpl in debug.h to reduce header dependencies Move common_debug_cb_user_data's data members (std::regex, std::vector<uint8_t>) into a private impl struct in debug.cpp. This removes the includes of common.h and <regex> from debug.h, reducing transitive dependencies for any translation unit that includes the header. Assisted-by: llama.cpp:local pi --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent f535774 commit 5594d13

6 files changed

Lines changed: 74 additions & 60 deletions

File tree

common/debug.cpp

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,38 @@
11
#include "debug.h"
22

3+
#include "common.h"
34
#include "log.h"
45

56
#include <cmath>
7+
#include <regex>
68
#include <string>
9+
#include <vector>
10+
11+
struct common_debug_cb_user_data::impl {
12+
std::vector<uint8_t> data;
13+
std::vector<std::regex> tensor_filters;
14+
bool abort_on_nan{false};
15+
};
16+
17+
common_debug_cb_user_data::common_debug_cb_user_data() : pimpl(std::make_unique<impl>()) {}
18+
common_debug_cb_user_data::~common_debug_cb_user_data() = default;
19+
20+
common_debug_cb_user_data::common_debug_cb_user_data(common_params & params, const std::vector<std::string> & filter_patterns, bool abort_on_nan)
21+
: pimpl(std::make_unique<impl>())
22+
{
23+
for (const auto & pattern : filter_patterns) {
24+
try {
25+
std::string anchored_pattern = "^" + pattern;
26+
pimpl->tensor_filters.emplace_back(anchored_pattern, std::regex::optimize);
27+
} catch (const std::regex_error & e) {
28+
throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what());
29+
}
30+
}
31+
pimpl->abort_on_nan = abort_on_nan;
32+
33+
params.cb_eval = common_debug_cb_eval;
34+
params.cb_eval_user_data = this;
35+
}
736

837
static std::string common_ggml_ne_string(const ggml_tensor * t) {
938
std::string str;
@@ -47,8 +76,7 @@ static float common_ggml_get_float_value(const uint8_t * data,
4776

4877
#define INDENT " "
4978

50-
template <bool abort>
51-
void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
79+
static void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n, bool abort_on_nan) {
5280
GGML_ASSERT(n > 0);
5381
float sum = 0;
5482
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
@@ -94,7 +122,7 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n
94122
LOG(INDENT "sum = %f\n", sum);
95123
}
96124

97-
if constexpr (abort) {
125+
if (abort_on_nan) {
98126
if (std::isnan(sum)) {
99127
LOG("encountered NaN - aborting\n");
100128
exit(0);
@@ -112,8 +140,9 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n
112140
* @param user_data user data to pass at each call back
113141
* @return true to receive data or continue the graph, false otherwise
114142
*/
115-
template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
116-
auto * cb_data = (base_callback_data *) user_data;
143+
bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
144+
auto * cb_data = (common_debug_cb_user_data *) user_data;
145+
auto * pimpl = cb_data->pimpl.get();
117146

118147
const struct ggml_tensor * src0 = t->src[0];
119148
const struct ggml_tensor * src1 = t->src[1];
@@ -122,10 +151,10 @@ template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, b
122151
return true; // Always retrieve data
123152
}
124153

125-
bool matches_filter = cb_data->tensor_filters.empty();
154+
bool matches_filter = pimpl->tensor_filters.empty();
126155

127156
if (!matches_filter) {
128-
for (const auto & filter : cb_data->tensor_filters) {
157+
for (const auto & filter : pimpl->tensor_filters) {
129158
if (std::regex_search(t->name, filter)) {
130159
matches_filter = true;
131160
break;
@@ -148,20 +177,14 @@ template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, b
148177

149178
if (!is_host) {
150179
auto n_bytes = ggml_nbytes(t);
151-
cb_data->data.resize(n_bytes);
152-
ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes);
180+
pimpl->data.resize(n_bytes);
181+
ggml_backend_tensor_get(t, pimpl->data.data(), 0, n_bytes);
153182
}
154183

155184
if (!ggml_is_quantized(t->type) && matches_filter) {
156-
uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
157-
common_debug_print_tensor<abort_on_nan>(data, t->type, t->ne, t->nb, 3);
185+
uint8_t * data = is_host ? (uint8_t *) t->data : pimpl->data.data();
186+
common_debug_print_tensor(data, t->type, t->ne, t->nb, 3, pimpl->abort_on_nan);
158187
}
159188

160189
return true;
161190
}
162-
163-
// Explicit template instantiations
164-
template bool common_debug_cb_eval<false>(ggml_tensor *, bool, void *);
165-
template bool common_debug_cb_eval<true>(ggml_tensor *, bool, void *);
166-
template void common_debug_print_tensor<false>(uint8_t *, ggml_type, const int64_t *, const size_t *, int64_t);
167-
template void common_debug_print_tensor<true>(uint8_t *, ggml_type, const int64_t *, const size_t *, int64_t);

common/debug.h

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,31 @@
11
#pragma once
2-
#include "common.h"
2+
3+
#include <memory>
34
#include <string>
45
#include <vector>
5-
#include <regex>
66

77
// common debug functions and structs
88

9-
// Print a tensor's detailed data
10-
// data - the tensor's data in byte format
11-
// type - the tensor's quantization type
12-
// ne - the tensor dimensions array
13-
// nb - the tensor strides array
14-
// n - the number of rows/columns to fully print
15-
template <bool abort_on_nan> void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n);
9+
struct common_params;
1610

1711
// Intended to use as callback for ggml_backend_sched_eval_callback
1812
// prints tensors that are processed in the computation graph
19-
// by default prints all tensors, but can be configured by creating a `base_callback_data` instance with
20-
// non-empty filter_patterns. See examples/debug.ccp for possible usage patterns
21-
// The template parameter determines whether an error should be thrown whenever a NaN is encountered
13+
// by default prints all tensors, but can be configured by creating a `common_debug_cb_user_data` instance with
14+
// non-empty filter_patterns. See examples/debug.cpp for possible usage patterns
15+
// `common_debug_cb_user_data` contains `abort_on_nan` flag that determines whether an error should be thrown whenever a NaN is encountered
2216
// in a tensor (useful for stopping debug sessions on first erroneous tensor)
2317
// The callback data will be passed as the third parameter (user_data)
24-
template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data);
25-
struct base_callback_data {
26-
std::vector<uint8_t> data;
27-
std::vector<std::regex> tensor_filters;
28-
29-
base_callback_data() = default;
30-
31-
base_callback_data(common_params & params, const std::vector<std::string> & filter_patterns) {
32-
for (const auto & pattern : filter_patterns) {
33-
try {
34-
std::string anchored_pattern = "^" + pattern;
35-
tensor_filters.emplace_back(anchored_pattern, std::regex::optimize);
36-
} catch (const std::regex_error & e) {
37-
throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what());
38-
}
39-
}
40-
params.cb_eval = common_debug_cb_eval<false>;
41-
params.cb_eval_user_data = this;
42-
}
18+
bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data);
19+
20+
struct common_debug_cb_user_data {
21+
struct impl;
22+
std::unique_ptr<impl> pimpl;
23+
24+
common_debug_cb_user_data();
25+
~common_debug_cb_user_data();
26+
27+
common_debug_cb_user_data(const common_debug_cb_user_data &) = delete;
28+
common_debug_cb_user_data & operator=(const common_debug_cb_user_data &) = delete;
29+
30+
common_debug_cb_user_data(common_params & params, const std::vector<std::string> & filter_patterns, bool abort_on_nan = false);
4331
};

examples/debug/debug.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,14 @@ static bool run(llama_context * ctx, const common_params & params) {
202202
print_tokenized_prompt(ctx, tokens, params.prompt);
203203

204204
if (params.save_logits) {
205-
output_data output {ctx, model, params};
206-
std::filesystem::path model_path{params.model.path};
207-
std::string model_name{model_path.stem().string()};
208-
save_output_data(output, model_name, params.logits_output_dir);
205+
try {
206+
output_data output {ctx, model, params};
207+
std::filesystem::path model_path{params.model.path};
208+
std::string model_name{model_path.stem().string()};
209+
save_output_data(output, model_name, params.logits_output_dir);
210+
} catch (const std::exception & e) {
211+
LOG_ERR("%s : error saving logits: %s\n", __func__, e.what());
212+
}
209213
}
210214

211215
return true;
@@ -223,7 +227,7 @@ int main(int argc, char ** argv) {
223227
llama_backend_init();
224228
llama_numa_init(params.numa);
225229

226-
std::optional<base_callback_data> cb_data;
230+
std::optional<common_debug_cb_user_data> cb_data;
227231
if (!params.save_logits) {
228232
cb_data.emplace(params, params.tensor_filter);
229233
}

examples/eval-callback/eval-callback.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "debug.h"
44
#include "log.h"
55
#include "llama.h"
6-
#include "llama-cpp.h"
76

87
#include <clocale>
98
#include <string>
@@ -38,7 +37,7 @@ static bool run(llama_context * ctx, const common_params & params) {
3837
int main(int argc, char ** argv) {
3938
std::setlocale(LC_NUMERIC, "C");
4039

41-
base_callback_data cb_data;
40+
common_debug_cb_user_data cb_data;
4241

4342
common_params params;
4443

@@ -53,7 +52,7 @@ int main(int argc, char ** argv) {
5352

5453
// pass the callback to the backend scheduler
5554
// it will be executed for each node during the graph computation
56-
params.cb_eval = common_debug_cb_eval<false>;
55+
params.cb_eval = common_debug_cb_eval;
5756
params.cb_eval_user_data = &cb_data;
5857
params.warmup = false;
5958

tools/mtmd/debug/mtmd-debug.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ int main(int argc, char ** argv) {
7272

7373
mtmd::context_ptr ctx_mtmd;
7474
common_init_result_ptr llama_init;
75-
base_callback_data cb_data;
75+
common_debug_cb_user_data cb_data;
7676

7777
llama_init = common_init_from_params(params);
7878
{
@@ -89,7 +89,7 @@ int main(int argc, char ** argv) {
8989
{
9090
// always enable debug callback
9191
mparams.cb_eval_user_data = &cb_data;
92-
mparams.cb_eval = common_debug_cb_eval<false>;
92+
mparams.cb_eval = common_debug_cb_eval;
9393
}
9494
ctx_mtmd.reset(mtmd_init_from_file(clip_path, model, mparams));
9595
if (!ctx_mtmd.get()) {

tools/mtmd/mtmd-cli.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ struct mtmd_cli_context {
9090
int n_threads = 1;
9191
llama_pos n_past = 0;
9292

93-
base_callback_data cb_data;
93+
common_debug_cb_user_data cb_data;
9494

9595
mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
9696
model = llama_init->model();
@@ -145,7 +145,7 @@ struct mtmd_cli_context {
145145
mparams.image_max_tokens = params.image_max_tokens;
146146
if (std::getenv("MTMD_DEBUG_GRAPH") != nullptr) {
147147
mparams.cb_eval_user_data = &cb_data;
148-
mparams.cb_eval = common_debug_cb_eval<false>;
148+
mparams.cb_eval = common_debug_cb_eval;
149149
}
150150
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
151151
if (!ctx_vision.get()) {

0 commit comments

Comments
 (0)