
Commit 58c1aec

logreg: Some progress on multi-class
1 parent a21a57d commit 58c1aec

5 files changed: 681 additions & 215 deletions

src/emlearn_logreg/eml_logreg.c: 158 additions & 53 deletions

@@ -1,5 +1,7 @@
 #include <math.h>
 #include <string.h>
+#include <float.h>
+#include <stdbool.h>
 
 static float logf_compat(float x) {
     if (x <= 0.0f) {
@@ -55,13 +57,45 @@ static float expf_compat(float x) {
 typedef struct {
     float *weights;
     float *weight_gradients;
-    float bias;
+    float *biases;
     uint16_t n_features;
+    uint16_t n_classes;
     float learning_rate;
     float lambda_l2;
     float lambda_l1;
 } logreg_model_t;
 
+typedef struct {
+    float *logits;
+    float *probabilities;
+    float *bias_gradients;
+    uint16_t logits_size;
+    uint16_t probabilities_size;
+    uint16_t bias_gradients_size;
+} logreg_workspace_t;
+
+static bool logreg_workspace_validate(const logreg_workspace_t *workspace,
+                                      uint16_t required_size,
+                                      bool need_bias) {
+    if (workspace == NULL) {
+        return false;
+    }
+    if (workspace->logits == NULL || workspace->probabilities == NULL) {
+        return false;
+    }
+    if (workspace->logits_size < required_size ||
+        workspace->probabilities_size < required_size) {
+        return false;
+    }
+    if (need_bias) {
+        if (workspace->bias_gradients == NULL ||
+            workspace->bias_gradients_size < required_size) {
+            return false;
+        }
+    }
+    return true;
+}
+
 static float soft_threshold(float value, float threshold) {
     if (value > threshold) {
         return value - threshold;
@@ -71,105 +105,176 @@ static float soft_threshold(float value, float threshold) {
     return 0.0f;
 }
 
-
-static float sigmoidf(float x) {
-    if (x > 10.0f) {
-        x = 10.0f;
-    } else if (x < -10.0f) {
-        x = -10.0f;
+// logits buffer must be size n_classes
+static void logreg_predict_scores(const logreg_model_t *model,
+                                  const float *features,
+                                  float *scores) {
+    const uint16_t n_classes = model->n_classes;
+    const uint16_t n_features = model->n_features;
+    for (uint16_t cls = 0; cls < n_classes; cls++) {
+        const float *weights_cls = &model->weights[cls * n_features];
+        float logit = model->biases[cls];
+        for (uint16_t feat = 0; feat < n_features; feat++) {
+            logit += weights_cls[feat] * features[feat];
+        }
+        scores[cls] = logit;
     }
-    float ex = expf_compat(x);
-    return ex / (1.0f + ex);
 }
 
-static float logreg_predict_proba(const logreg_model_t *model, const float *features) {
-    float logit = model->bias;
-    for (uint16_t i = 0; i < model->n_features; i++) {
-        logit += model->weights[i] * features[i];
+// logits and probabilities buffers must be size n_classes
+static void logreg_softmax(const float *logits,
+                           uint16_t n_classes,
+                           float *probabilities) {
+    float max_logit = -FLT_MAX;
+    for (uint16_t cls = 0; cls < n_classes; cls++) {
+        if (logits[cls] > max_logit) {
+            max_logit = logits[cls];
+        }
+    }
+
+    float sum = 0.0f;
+    for (uint16_t cls = 0; cls < n_classes; cls++) {
+        float value = expf_compat(logits[cls] - max_logit);
+        probabilities[cls] = value;
+        sum += value;
+    }
+
+    const float inv_sum = 1.0f / sum;
+    for (uint16_t cls = 0; cls < n_classes; cls++) {
+        probabilities[cls] *= inv_sum;
     }
-    return sigmoidf(logit);
 }
 
-static void logreg_iterate(logreg_model_t *model,
-                           const float *X,
-                           const float *y,
-                           uint16_t n_samples) {
-    if (n_samples == 0) {
-        return;
+static void logreg_predict_softmax(const logreg_model_t *model,
+                                   const float *features,
+                                   float *probabilities,
+                                   float *logits) {
+    const uint16_t n_classes = model->n_classes;
+    logreg_predict_scores(model, features, logits);
+    logreg_softmax(logits, n_classes, probabilities);
+}
+
+bool logreg_iterate(logreg_model_t *model,
+                    const float *X,
+                    const float *y,
+                    uint16_t n_samples,
+                    logreg_workspace_t *workspace) {
+    if (n_samples == 0 || model->n_classes == 0) {
+        return true;
+    }
+
+    if (!logreg_workspace_validate(workspace, model->n_classes, true)) {
+        return false;
     }
 
     const uint16_t n_features = model->n_features;
+    const uint16_t n_classes = model->n_classes;
+
+    memset(model->weight_gradients, 0, n_classes * n_features * sizeof(float));
+    memset(workspace->bias_gradients, 0, n_classes * sizeof(float));
 
-    memset(model->weight_gradients, 0, n_features * sizeof(float));
-    float bias_gradient = 0.0f;
+    float *logits_ptr = workspace->logits;
+    float *probs_ptr = workspace->probabilities;
 
     for (uint16_t sample = 0; sample < n_samples; sample++) {
         const float *features = &X[sample * n_features];
-        const float target = y[sample];
-        const float prediction = logreg_predict_proba(model, features);
-        const float error = prediction - target;
+        const float *target = &y[sample * n_classes];
+        logreg_predict_softmax(model, features, probs_ptr, logits_ptr);
 
-        bias_gradient += error;
-        for (uint16_t feat = 0; feat < n_features; feat++) {
-            model->weight_gradients[feat] += error * features[feat];
+        for (uint16_t cls = 0; cls < n_classes; cls++) {
+            const float error = probs_ptr[cls] - target[cls];
+            workspace->bias_gradients[cls] += error;
+            float *grad_weights = &model->weight_gradients[cls * n_features];
+            for (uint16_t feat = 0; feat < n_features; feat++) {
+                grad_weights[feat] += error * features[feat];
+            }
         }
     }
 
     const float inv_samples = 1.0f / (float)n_samples;
-    bias_gradient *= inv_samples;
-    for (uint16_t feat = 0; feat < n_features; feat++) {
-        model->weight_gradients[feat] *= inv_samples;
+    for (uint16_t cls = 0; cls < n_classes; cls++) {
+        workspace->bias_gradients[cls] *= inv_samples;
+        float *grad_weights = &model->weight_gradients[cls * n_features];
+        for (uint16_t feat = 0; feat < n_features; feat++) {
+            grad_weights[feat] *= inv_samples;
+        }
    }
 
     const float lr = model->learning_rate;
     const float l2 = model->lambda_l2;
     const float l1 = model->lambda_l1;
     const float l1_threshold = lr * l1;
 
-    for (uint16_t feat = 0; feat < n_features; feat++) {
-        float grad = model->weight_gradients[feat] + l2 * model->weights[feat];
-        float updated = model->weights[feat] - lr * grad;
-        model->weights[feat] = soft_threshold(updated, l1_threshold);
+    for (uint16_t cls = 0; cls < n_classes; cls++) {
+        float *weights_cls = &model->weights[cls * n_features];
+        float *grad_weights = &model->weight_gradients[cls * n_features];
+        for (uint16_t feat = 0; feat < n_features; feat++) {
+            float grad = grad_weights[feat] + l2 * weights_cls[feat];
+            float updated = weights_cls[feat] - lr * grad;
+            weights_cls[feat] = soft_threshold(updated, l1_threshold);
+        }
+        model->biases[cls] -= lr * workspace->bias_gradients[cls];
     }
 
-    model->bias -= lr * bias_gradient;
+    return true;
 }
 
-static float logreg_logloss(const logreg_model_t *model,
-                            const float *X,
-                            const float *y,
-                            uint16_t n_samples) {
-    if (n_samples == 0) {
-        return 0.0f;
+bool logreg_logloss(const logreg_model_t *model,
+                    const float *X,
+                    const float *y,
+                    uint16_t n_samples,
+                    logreg_workspace_t *workspace,
+                    float *loss_out) {
+    if (loss_out == NULL) {
+        return false;
+    }
+
+    if (n_samples == 0 || model->n_classes == 0) {
+        *loss_out = 0.0f;
+        return true;
+    }
+
+    if (!logreg_workspace_validate(workspace, model->n_classes, false)) {
+        *loss_out = 0.0f;
+        return false;
    }
 
     const float eps = 1e-7f;
     float loss = 0.0f;
+    const uint16_t n_features = model->n_features;
+    const uint16_t n_classes = model->n_classes;
+    float *logits_ptr = workspace->logits;
+    float *probs_ptr = workspace->probabilities;
 
     for (uint16_t sample = 0; sample < n_samples; sample++) {
-        const float *features = &X[sample * model->n_features];
-        float prediction = logreg_predict_proba(model, features);
-        if (prediction < eps) {
-            prediction = eps;
-        } else if (prediction > 1.0f - eps) {
-            prediction = 1.0f - eps;
+        const float *features = &X[sample * n_features];
+        const float *target = &y[sample * n_classes];
+        logreg_predict_softmax(model, features, probs_ptr, logits_ptr);
+        for (uint16_t cls = 0; cls < n_classes; cls++) {
+            float prediction = probs_ptr[cls];
+            if (prediction < eps) {
+                prediction = eps;
+            } else if (prediction > 1.0f - eps) {
+                prediction = 1.0f - eps;
+            }
+            loss -= target[cls] * logf_compat(prediction);
         }
-        const float target = y[sample];
-        loss -= target * logf_compat(prediction) + (1.0f - target) * logf_compat(1.0f - prediction);
     }
 
     loss /= (float)n_samples;
 
     if (model->lambda_l2 > 0.0f || model->lambda_l1 > 0.0f) {
         float l2_term = 0.0f;
         float l1_term = 0.0f;
-        for (uint16_t feat = 0; feat < model->n_features; feat++) {
-            const float weight = model->weights[feat];
+        const uint32_t total_weights = (uint32_t)n_features * (uint32_t)n_classes;
+        for (uint32_t idx = 0; idx < total_weights; idx++) {
+            const float weight = model->weights[idx];
             l2_term += weight * weight;
             l1_term += fabsf(weight);
         }
         loss += 0.5f * model->lambda_l2 * l2_term + model->lambda_l1 * l1_term;
     }
 
-    return loss;
+    *loss_out = loss;
+    return true;
 }
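
Note on the gradient in logreg_iterate: the per-class error is a plain subtraction because softmax and cross-entropy compose cleanly. For probabilities p_c = softmax(z)_c over logits z_c and one-hot targets y_c, the loss L = -sum_c y_c * log(p_c) has gradient dL/dz_c = p_c - y_c, which is exactly the error = probs_ptr[cls] - target[cls] term the inner loop accumulates into the bias and weight gradients; the softmax Jacobian never has to be formed explicitly.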
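
For reference, a minimal caller sketch for the new workspace-based API. It assumes the types and functions above are exposed to callers through a header (named eml_logreg.h here purely for illustration); the toy data, hyperparameters, and epoch count are arbitrary and not part of this commit. Buffer sizes follow the struct fields in the diff.

// Hypothetical caller of the new multi-class API.
#include <stdio.h>
#include "eml_logreg.h" // hypothetical header exposing the types/functions above

#define N_CLASSES  3
#define N_FEATURES 2
#define N_SAMPLES  4

int main(void) {
    // Model storage: weights laid out row-major as [class][feature].
    float weights[N_CLASSES * N_FEATURES] = {0};
    float weight_gradients[N_CLASSES * N_FEATURES] = {0};
    float biases[N_CLASSES] = {0};
    logreg_model_t model = {
        .weights = weights, .weight_gradients = weight_gradients,
        .biases = biases,
        .n_features = N_FEATURES, .n_classes = N_CLASSES,
        .learning_rate = 0.1f, .lambda_l2 = 0.0f, .lambda_l1 = 0.0f,
    };

    // Workspace: each buffer must hold at least n_classes entries,
    // or logreg_workspace_validate() rejects it.
    float logits[N_CLASSES], probabilities[N_CLASSES], bias_gradients[N_CLASSES];
    logreg_workspace_t ws = {
        .logits = logits, .probabilities = probabilities,
        .bias_gradients = bias_gradients,
        .logits_size = N_CLASSES, .probabilities_size = N_CLASSES,
        .bias_gradients_size = N_CLASSES,
    };

    // Toy data: X is n_samples x n_features; y is one-hot, n_samples x n_classes.
    const float X[N_SAMPLES * N_FEATURES] = {
        0.0f, 0.1f,  1.0f, 0.9f,  0.1f, 1.0f,  0.9f, 0.0f,
    };
    const float y[N_SAMPLES * N_CLASSES] = {
        1.0f, 0.0f, 0.0f,  0.0f, 1.0f, 0.0f,
        0.0f, 0.0f, 1.0f,  1.0f, 0.0f, 0.0f,
    };

    // One full-batch gradient step per call; iterate to convergence.
    for (int epoch = 0; epoch < 100; epoch++) {
        if (!logreg_iterate(&model, X, y, N_SAMPLES, &ws)) {
            return 1; // workspace rejected
        }
    }

    float loss;
    if (logreg_logloss(&model, X, y, N_SAMPLES, &ws, &loss)) {
        printf("final loss: %f\n", loss);
    }
    return 0;
}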
