Skip to content

Commit ec77c73

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 33ef08b commit ec77c73

3 files changed

Lines changed: 14 additions & 23 deletions

File tree

xcsf/pybind_callback.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,10 @@ extern "C" {
3838
*/
3939
class Callback
4040
{
41-
public:
41+
public:
4242
virtual ~Callback() {}
4343

44-
virtual bool
45-
run(struct XCSF *xcsf, py::dict metrics) = 0;
44+
virtual bool run(struct XCSF * xcsf, py::dict metrics) = 0;
4645

47-
virtual void
48-
finish(struct XCSF *xcsf) = 0;
46+
virtual void finish(struct XCSF * xcsf) = 0;
4947
};

xcsf/pybind_callback_checkpoint.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ extern "C" {
4141
*/
4242
class CheckpointCallback : public Callback
4343
{
44-
public:
44+
public:
4545
/**
4646
* @brief Constructs a new checkpoint callback.
4747
* @param [in] monitor Name of the metric to monitor: {"train", "val"}.
@@ -74,8 +74,7 @@ class CheckpointCallback : public Callback
7474
* @brief Saves the state of XCSF.
7575
* @param [in] xcsf The XCSF data structure.
7676
*/
77-
void
78-
save(struct XCSF *xcsf)
77+
void save(struct XCSF * xcsf)
7978
{
8079
xcsf_save(xcsf, filename.c_str());
8180
std::ostringstream status;
@@ -90,8 +89,7 @@ class CheckpointCallback : public Callback
9089
* @param [in] metrics Dictionary of performance metrics.
9190
* @return Whether to terminate training.
9291
*/
93-
bool
94-
run(struct XCSF *xcsf, py::dict metrics) override
92+
bool run(struct XCSF * xcsf, py::dict metrics) override
9593
{
9694
py::list data = metrics[monitor];
9795
py::list trials = metrics["trials"];
@@ -113,15 +111,14 @@ class CheckpointCallback : public Callback
113111
* @brief Executes any tasks at the end of fitting.
114112
* @param [in] xcsf The XCSF data structure.
115113
*/
116-
void
117-
finish(struct XCSF *xcsf) override
114+
void finish(struct XCSF * xcsf) override
118115
{
119116
if (!save_best_only) {
120117
save(xcsf);
121118
}
122119
}
123120

124-
private:
121+
private:
125122
py::str monitor; //!< Name of the metric to monitor
126123
std::string filename; //!< Name of the file to save XCSF
127124
bool save_best_only; //!< Whether to only save the best population

xcsf/pybind_callback_earlystop.h

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ extern "C" {
4141
*/
4242
class EarlyStoppingCallback : public Callback
4343
{
44-
public:
44+
public:
4545
/**
4646
* @brief Constructs a new early stopping callback.
4747
* @param [in] monitor Name of the metric to monitor: {"train", "val"}.
@@ -82,8 +82,7 @@ class EarlyStoppingCallback : public Callback
8282
* @brief Stores best XCSF population in memory.
8383
* @param [in] xcsf The XCSF data structure.
8484
*/
85-
void
86-
store(struct XCSF *xcsf)
85+
void store(struct XCSF * xcsf)
8786
{
8887
do_restore = true;
8988
xcsf_store_pset(xcsf);
@@ -100,8 +99,7 @@ class EarlyStoppingCallback : public Callback
10099
* @brief Retrieves best XCSF population in memory.
101100
* @param [in] xcsf The XCSF data structure.
102101
*/
103-
void
104-
retrieve(struct XCSF *xcsf)
102+
void retrieve(struct XCSF * xcsf)
105103
{
106104
do_restore = false;
107105
xcsf_retrieve_pset(xcsf);
@@ -121,8 +119,7 @@ class EarlyStoppingCallback : public Callback
121119
* @param [in] metrics Dictionary of performance metrics.
122120
* @return whether early stopping criteria has been met.
123121
*/
124-
bool
125-
run(struct XCSF *xcsf, py::dict metrics) override
122+
bool run(struct XCSF * xcsf, py::dict metrics) override
126123
{
127124
py::list data = metrics[monitor];
128125
py::list trials = metrics["trials"];
@@ -156,15 +153,14 @@ class EarlyStoppingCallback : public Callback
156153
* @brief Executes any tasks at the end of fitting.
157154
* @param [in] xcsf The XCSF data structure.
158155
*/
159-
void
160-
finish(struct XCSF *xcsf) override
156+
void finish(struct XCSF * xcsf) override
161157
{
162158
if (restore && do_restore) {
163159
retrieve(xcsf);
164160
}
165161
}
166162

167-
private:
163+
private:
168164
py::str monitor; //!< Name of the metric to monitor
169165
int patience; //!< Stop training after this many trials with no improvement
170166
bool restore; //!< Whether to restore the best population

0 commit comments

Comments
 (0)