Skip to content

Commit 05374a5

Browse files
committed
add state to callbacks
1 parent 6dc65d8 commit 05374a5

10 files changed

Lines changed: 224 additions & 23 deletions

File tree

examples/c/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ set(C_EXAMPLES
1212
example_json
1313
example_print_stream
1414
example_callback
15+
example_callback_with_state
1516
example_pardiso_mkl
1617
)
1718

examples/c/example_callback.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include <Clarabel.h>
44

55

6-
int custom_callback(ClarabelDefaultInfo *info)
6+
int custom_callback(ClarabelDefaultInfo *info, void* _userdata)
77
{
88
// This function is called at each iteration of the solver.
99
// You can use it to monitor the progress of the solver or
@@ -71,7 +71,7 @@ int main(void)
7171
);
7272

7373
// configure a custom callback function
74-
clarabel_DefaultSolver_set_termination_callback(solver,custom_callback);
74+
clarabel_DefaultSolver_set_termination_callback(solver,custom_callback, NULL);
7575

7676
// Solve
7777
clarabel_DefaultSolver_solve(solver);
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#include "utils.h"
2+
#include <stdio.h>
3+
#include <Clarabel.h>
4+
5+
6+
typedef struct
7+
{
8+
int count;
9+
} CallbackData;
10+
11+
12+
int custom_callback(ClarabelDefaultInfo *info, void* userdata)
13+
{
14+
// Cast the userdata pointer back to our struct type
15+
CallbackData* data = (CallbackData*)userdata;
16+
17+
// Access and modify the state
18+
data->count++;
19+
20+
// Return 0 to continue. Anything else to stop.
21+
if (data->count < 3) {
22+
printf("tick\n");
23+
return 0; //continue
24+
}
25+
else {
26+
printf("BOOM!\n");
27+
return 1; // stop
28+
}
29+
}
30+
31+
int main(void)
32+
{
33+
// 2 x 2 zero matrix
34+
ClarabelCscMatrix P;
35+
clarabel_CscMatrix_init(
36+
&P,
37+
2, // row
38+
2, // col
39+
(uintptr_t[]){ 0, 0, 0 }, // colptr
40+
NULL, // rowval
41+
NULL // nzval
42+
);
43+
44+
ClarabelFloat q[2] = { 1.0, -1.0 };
45+
46+
// a 2-d box constraint, separated into 4 inequalities.
47+
// A = [I; -I]
48+
ClarabelCscMatrix A;
49+
clarabel_CscMatrix_init(
50+
&A,
51+
4, // row
52+
2, // col
53+
(uintptr_t[]){ 0, 2, 4 }, // colptr
54+
(uintptr_t[]){ 0, 2, 1, 3 }, // rowval
55+
(ClarabelFloat[]){ 1.0, -1.0, 1.0, -1.0 } // nzval
56+
);
57+
58+
ClarabelFloat b[4] = { 1.0, 1.0, 1.0, 1.0 };
59+
60+
ClarabelSupportedConeT cones[1] = { ClarabelNonnegativeConeT(4) };
61+
62+
// Settings
63+
ClarabelDefaultSettings settings = clarabel_DefaultSettings_default();
64+
settings.equilibrate_enable = true;
65+
settings.equilibrate_max_iter = 50;
66+
67+
// Build solver
68+
ClarabelDefaultSolver *solver = clarabel_DefaultSolver_new(
69+
&P, // P
70+
q, // q
71+
&A, // A
72+
b, // b
73+
1, // n_cones
74+
cones, &settings
75+
);
76+
77+
78+
// configure a custom callback function
79+
CallbackData userdata = {-1};
80+
clarabel_DefaultSolver_set_termination_callback(solver,custom_callback,&userdata);
81+
82+
// Solve
83+
clarabel_DefaultSolver_solve(solver);
84+
85+
// turn off the callback
86+
clarabel_DefaultSolver_unset_termination_callback(solver);
87+
88+
// Solve again
89+
clarabel_DefaultSolver_solve(solver);
90+
91+
// Get solution
92+
ClarabelDefaultSolution solution = clarabel_DefaultSolver_solution(solver);
93+
print_solution(&solution);
94+
95+
// Free the solver
96+
clarabel_DefaultSolver_free(solver);
97+
98+
return 0;
99+
}
100+

examples/cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ set(CPP_EXAMPLES
1010
example_json
1111
example_print_stream
1212
example_callback
13+
example_callback_with_state
1314
example_faer
1415
example_pardiso_mkl
1516
)

examples/cpp/example_callback.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using namespace clarabel;
88
using namespace std;
99
using namespace Eigen;
1010

11-
int custom_callback(DefaultInfo<double> &info)
11+
int custom_callback(DefaultInfo<double> &info, void* _userdata)
1212
{
1313
// This function is called at each iteration of the solver.
1414
// You can use it to monitor the progress of the solver or
@@ -67,7 +67,7 @@ int main()
6767
DefaultSolver<double> solver(P, q, A, b, cones, settings);
6868

6969
// configure a custom callback function
70-
solver.set_termination_callback(custom_callback);
70+
solver.set_termination_callback(custom_callback, nullptr);
7171

7272
// Solve
7373
solver.solve();
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#include "utils.h"
2+
3+
#include <Clarabel>
4+
#include <Eigen/Eigen>
5+
#include <vector>
6+
7+
using namespace clarabel;
8+
using namespace std;
9+
using namespace Eigen;
10+
11+
12+
typedef struct
13+
{
14+
int count;
15+
} CallbackData;
16+
17+
18+
int custom_callback(DefaultInfo<double> &info, void* userdata)
19+
{
20+
// Cast the userdata pointer back to our struct type
21+
CallbackData* data = (CallbackData*)userdata;
22+
23+
// Access and modify the state
24+
data->count++;
25+
26+
// Return 0 to continue. Anything else to stop.
27+
if (data->count < 3) {
28+
printf("tick\n");
29+
return 0; //continue
30+
}
31+
else {
32+
printf("BOOM!\n");
33+
return 1; // stop
34+
}
35+
}
36+
37+
int main()
38+
{
39+
MatrixXd P_dense = MatrixXd::Zero(2, 2);
40+
SparseMatrix<double> P = P_dense.sparseView();
41+
P.makeCompressed();
42+
43+
Vector<double, 2> q = {1.0, -1.0};
44+
45+
// a 2-d box constraint, separated into 4 inequalities.
46+
// A = [I; -I]
47+
MatrixXd A_dense(4, 2);
48+
A_dense <<
49+
1., 0.,
50+
0., 1.,
51+
-1., 0.,
52+
0., -1.;
53+
54+
SparseMatrix<double> A = A_dense.sparseView();
55+
A.makeCompressed();
56+
57+
Vector<double, 4> b = { 1.0, 1.0, 1.0, 1.0 };
58+
59+
vector<SupportedConeT<double>> cones
60+
{
61+
NonnegativeConeT<double>(4),
62+
// {.tag = SupportedConeT<double>::Tag::NonnegativeConeT, .nonnegative_cone_t = {._0 = 4 }}
63+
};
64+
65+
// Settings
66+
DefaultSettings<double> settings = DefaultSettingsBuilder<double>::default_settings()
67+
.equilibrate_enable(true)
68+
.equilibrate_max_iter(50)
69+
.build();
70+
71+
// Build solver
72+
DefaultSolver<double> solver(P, q, A, b, cones, settings);
73+
74+
75+
76+
// configure a custom callback function
77+
CallbackData userdata = {-1};
78+
solver.set_termination_callback(custom_callback, &userdata);
79+
80+
// Solve
81+
solver.solve();
82+
83+
// turn off the callback
84+
solver.unset_termination_callback();
85+
86+
// Solve again
87+
solver.solve();
88+
89+
90+
// Get solution
91+
DefaultSolution<double> solution = solver.solution();
92+
utils::print_solution(solution);
93+
94+
return 0;
95+
}

include/c/DefaultSolver.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -212,24 +212,24 @@ static inline ClarabelDefaultInfo clarabel_DefaultSolver_info(ClarabelDefaultSol
212212
}
213213

214214
// DefaultSolver callbacks
215-
typedef int (*ClarabelCallbackFcn_f32)(ClarabelDefaultInfo_f32 *info);
216-
typedef int (*ClarabelCallbackFcn_f64)(ClarabelDefaultInfo_f64 *info);
215+
typedef int (*ClarabelCallbackFcn_f32)(ClarabelDefaultInfo_f32 *info, void* userdata);
216+
typedef int (*ClarabelCallbackFcn_f64)(ClarabelDefaultInfo_f64 *info, void* userdata);
217217

218218
#ifdef CLARABEL_USE_FLOAT
219219
typedef ClarabelCallbackFcn_f32 ClarabelCallbackFcn;
220220
#else
221221
typedef ClarabelCallbackFcn_f64 ClarabelCallbackFcn;
222222
#endif
223223

224-
void clarabel_DefaultSolver_f64_set_termination_callback(ClarabelDefaultSolver_f64 *solver, ClarabelCallbackFcn_f64 callback);
225-
void clarabel_DefaultSolver_f32_set_termination_callback(ClarabelDefaultSolver_f32 *solver, ClarabelCallbackFcn_f32 callback);
224+
void clarabel_DefaultSolver_f64_set_termination_callback(ClarabelDefaultSolver_f64 *solver, ClarabelCallbackFcn_f64 callback, void* userdata);
225+
void clarabel_DefaultSolver_f32_set_termination_callback(ClarabelDefaultSolver_f32 *solver, ClarabelCallbackFcn_f32 callback, void* userdata);
226226

227-
static inline void clarabel_DefaultSolver_set_termination_callback(ClarabelDefaultSolver *solver, ClarabelCallbackFcn callback)
227+
static inline void clarabel_DefaultSolver_set_termination_callback(ClarabelDefaultSolver *solver, ClarabelCallbackFcn callback, void* userdata)
228228
{
229229
#ifdef CLARABEL_USE_FLOAT
230-
clarabel_DefaultSolver_f32_set_termination_callback(solver,callback);
230+
clarabel_DefaultSolver_f32_set_termination_callback(solver,callback, userdata);
231231
#else
232-
clarabel_DefaultSolver_f64_set_termination_callback(solver,callback);
232+
clarabel_DefaultSolver_f64_set_termination_callback(solver,callback, userdata);
233233
#endif
234234
}
235235

include/cpp/DefaultSolver.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ class DefaultSolver
122122
// termination callbacks
123123
// -------------------------------
124124
void set_termination_callback(
125-
int (*callback)(clarabel::DefaultInfo<T>&));
125+
int (*callback)(clarabel::DefaultInfo<T>&,void*), void* userdata);
126126

127127
void unset_termination_callback();
128128

@@ -233,8 +233,8 @@ DefaultInfo<double> clarabel_DefaultSolver_f64_info(RustDefaultSolverHandle_f64
233233

234234
DefaultInfo<float> clarabel_DefaultSolver_f32_info(RustDefaultSolverHandle_f32 solver);
235235

236-
void clarabel_DefaultSolver_f64_set_termination_callback(RustDefaultSolverHandle_f64 solver, int (*callback)(DefaultInfo<double>&));
237-
void clarabel_DefaultSolver_f32_set_termination_callback(RustDefaultSolverHandle_f32 solver, int (*callback)(DefaultInfo<float>&));
236+
void clarabel_DefaultSolver_f64_set_termination_callback(RustDefaultSolverHandle_f64 solver, int (*callback)(DefaultInfo<double>& ,void*),void* userdata);
237+
void clarabel_DefaultSolver_f32_set_termination_callback(RustDefaultSolverHandle_f32 solver, int (*callback)(DefaultInfo<float>&, void*),void* userdata);
238238
void clarabel_DefaultSolver_f64_unset_termination_callback(RustDefaultSolverHandle_f64 solver);
239239
void clarabel_DefaultSolver_f32_unset_termination_callback(RustDefaultSolverHandle_f32 solver);
240240

@@ -392,13 +392,13 @@ inline DefaultInfo<float> DefaultSolver<float>::info() const
392392

393393

394394
template<>
395-
inline void DefaultSolver<double>::set_termination_callback(int (*callback)(DefaultInfo<double>&)) {
396-
clarabel_DefaultSolver_f64_set_termination_callback(this->handle, callback);
395+
inline void DefaultSolver<double>::set_termination_callback(int (*callback)(DefaultInfo<double>&, void*), void* userdata) {
396+
clarabel_DefaultSolver_f64_set_termination_callback(this->handle, callback,userdata);
397397
}
398398

399399
template<>
400-
inline void DefaultSolver<float>::set_termination_callback(int (*callback)(DefaultInfo<float>&)) {
401-
clarabel_DefaultSolver_f32_set_termination_callback(this->handle, callback);
400+
inline void DefaultSolver<float>::set_termination_callback(int (*callback)(DefaultInfo<float>&, void*), void* userdata) {
401+
clarabel_DefaultSolver_f32_set_termination_callback(this->handle, callback, userdata);
402402
}
403403

404404
template<>

rust_wrapper/src/solver/implementations/default/callbacks.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,40 @@ use clarabel::algebra::FloatT;
77
use clarabel::solver::{self as lib};
88
use std::ffi::{c_int, c_void};
99

10-
pub(crate) type CallbackFcnFFI<T> = extern "C" fn(info: *const ClarabelDefaultInfo<T>) -> c_int;
10+
pub(crate) type CallbackFcnFFI<T> =
11+
extern "C" fn(info: *const ClarabelDefaultInfo<T>, userdata: *mut std::ffi::c_void) -> c_int;
1112
pub type ClarabelCallbackFcn_f32 = CallbackFcnFFI<f32>;
1213
pub type ClarabelCallbackFcn_f64 = CallbackFcnFFI<f64>;
1314

1415
/// Set the termination callback
1516
fn _internal_DefaultSolver_set_termination_callback<T: FloatT>(
1617
solver: *mut c_void,
1718
callback: CallbackFcnFFI<T>,
19+
userdata: *mut std::ffi::c_void,
1820
) {
1921
// Recover the solver object from the opaque pointer
2022
let solver = unsafe { &mut *(solver as *mut lib::DefaultSolver<T>) };
2123

2224
// Set the termination callback
23-
solver.set_termination_callback_c(callback);
25+
solver.set_termination_callback_c(callback, userdata);
2426
}
2527

2628
#[no_mangle]
2729
pub extern "C" fn clarabel_DefaultSolver_f64_set_termination_callback(
2830
solver: *mut ClarabelDefaultSolver_f64,
2931
callback: ClarabelCallbackFcn_f64,
32+
userdata: *mut std::ffi::c_void,
3033
) {
31-
_internal_DefaultSolver_set_termination_callback::<f64>(solver, callback)
34+
_internal_DefaultSolver_set_termination_callback::<f64>(solver, callback, userdata)
3235
}
3336

3437
#[no_mangle]
3538
pub extern "C" fn clarabel_DefaultSolver_f32_set_termination_callback(
3639
solver: *mut ClarabelDefaultSolver_f32,
3740
callback: ClarabelCallbackFcn_f32,
41+
userdata: *mut std::ffi::c_void,
3842
) {
39-
_internal_DefaultSolver_set_termination_callback::<f32>(solver, callback)
43+
_internal_DefaultSolver_set_termination_callback::<f32>(solver, callback, userdata)
4044
}
4145

4246
/// Turn off the termination callback

0 commit comments

Comments
 (0)