11#include " debug.h"
22
3+ #include " common.h"
34#include " log.h"
45
56#include < cmath>
7+ #include < regex>
68#include < string>
9+ #include < vector>
10+
11+ struct common_debug_cb_user_data ::impl {
12+ std::vector<uint8_t > data;
13+ std::vector<std::regex> tensor_filters;
14+ bool abort_on_nan{false };
15+ };
16+
17+ common_debug_cb_user_data::common_debug_cb_user_data () : pimpl(std::make_unique<impl>()) {}
18+ common_debug_cb_user_data::~common_debug_cb_user_data () = default ;
19+
20+ common_debug_cb_user_data::common_debug_cb_user_data (common_params & params, const std::vector<std::string> & filter_patterns, bool abort_on_nan)
21+ : pimpl(std::make_unique<impl>())
22+ {
23+ for (const auto & pattern : filter_patterns) {
24+ try {
25+ std::string anchored_pattern = " ^" + pattern;
26+ pimpl->tensor_filters .emplace_back (anchored_pattern, std::regex::optimize);
27+ } catch (const std::regex_error & e) {
28+ throw std::runtime_error (" Invalid regex pattern '" + pattern + " ': " + e.what ());
29+ }
30+ }
31+ pimpl->abort_on_nan = abort_on_nan;
32+
33+ params.cb_eval = common_debug_cb_eval;
34+ params.cb_eval_user_data = this ;
35+ }
736
837static std::string common_ggml_ne_string (const ggml_tensor * t) {
938 std::string str;
@@ -47,8 +76,7 @@ static float common_ggml_get_float_value(const uint8_t * data,
4776
4877#define INDENT " "
4978
50- template <bool abort>
51- void common_debug_print_tensor (uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
79+ static void common_debug_print_tensor (uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n, bool abort_on_nan) {
5280 GGML_ASSERT (n > 0 );
5381 float sum = 0 ;
5482 for (int64_t i3 = 0 ; i3 < ne[3 ]; i3++) {
@@ -94,7 +122,7 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n
94122 LOG (INDENT " sum = %f\n " , sum);
95123 }
96124
97- if constexpr (abort ) {
125+ if (abort_on_nan ) {
98126 if (std::isnan (sum)) {
99127 LOG (" encountered NaN - aborting\n " );
100128 exit (0 );
@@ -112,8 +140,9 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n
112140 * @param user_data user data to pass at each call back
113141 * @return true to receive data or continue the graph, false otherwise
114142 */
115- template <bool abort_on_nan> bool common_debug_cb_eval (struct ggml_tensor * t, bool ask, void * user_data) {
116- auto * cb_data = (base_callback_data *) user_data;
143+ bool common_debug_cb_eval (struct ggml_tensor * t, bool ask, void * user_data) {
144+ auto * cb_data = (common_debug_cb_user_data *) user_data;
145+ auto * pimpl = cb_data->pimpl .get ();
117146
118147 const struct ggml_tensor * src0 = t->src [0 ];
119148 const struct ggml_tensor * src1 = t->src [1 ];
@@ -122,10 +151,10 @@ template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, b
122151 return true ; // Always retrieve data
123152 }
124153
125- bool matches_filter = cb_data ->tensor_filters .empty ();
154+ bool matches_filter = pimpl ->tensor_filters .empty ();
126155
127156 if (!matches_filter) {
128- for (const auto & filter : cb_data ->tensor_filters ) {
157+ for (const auto & filter : pimpl ->tensor_filters ) {
129158 if (std::regex_search (t->name , filter)) {
130159 matches_filter = true ;
131160 break ;
@@ -148,20 +177,14 @@ template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, b
148177
149178 if (!is_host) {
150179 auto n_bytes = ggml_nbytes (t);
151- cb_data ->data .resize (n_bytes);
152- ggml_backend_tensor_get (t, cb_data ->data .data (), 0 , n_bytes);
180+ pimpl ->data .resize (n_bytes);
181+ ggml_backend_tensor_get (t, pimpl ->data .data (), 0 , n_bytes);
153182 }
154183
155184 if (!ggml_is_quantized (t->type ) && matches_filter) {
156- uint8_t * data = is_host ? (uint8_t *) t->data : cb_data ->data .data ();
157- common_debug_print_tensor<abort_on_nan> (data, t->type , t->ne , t->nb , 3 );
185+ uint8_t * data = is_host ? (uint8_t *) t->data : pimpl ->data .data ();
186+ common_debug_print_tensor (data, t->type , t->ne , t->nb , 3 , pimpl-> abort_on_nan );
158187 }
159188
160189 return true ;
161190}
162-
163- // Explicit template instantiations
164- template bool common_debug_cb_eval<false >(ggml_tensor *, bool , void *);
165- template bool common_debug_cb_eval<true >(ggml_tensor *, bool , void *);
166- template void common_debug_print_tensor<false >(uint8_t *, ggml_type, const int64_t *, const size_t *, int64_t );
167- template void common_debug_print_tensor<true >(uint8_t *, ggml_type, const int64_t *, const size_t *, int64_t );
0 commit comments