# Prefix prepended to all console output from this module, so training
# log lines are attributable when mixed with other output.
log_prefix = 'emlearn_logreg:'
44
55
6- def _make_workspace_triplet (n_classes ):
7- return (
8- array .array ('f' , [0.0 ] * n_classes ),
9- array .array ('f' , [0.0 ] * n_classes ),
10- array .array ('f' , [0.0 ] * n_classes ),
11- )
6+ def _make_buffer (n ):
7+ return array .array ('f' , [0.0 ] * n )
128
139
14- def _make_workspace_pair (n_classes ):
15- return (
16- array .array ('f' , [0.0 ] * n_classes ),
17- array .array ('f' , [0.0 ] * n_classes ),
18- )
19-
20-
21- def _make_predict_buffers (n_classes ):
22- return (
23- array .array ('f' , [0.0 ] * n_classes ),
24- array .array ('f' , [0.0 ] * n_classes ),
25- )
2610
2711def train (model , X_train , y_train ,
2812 max_iterations = 200 ,
@@ -31,12 +15,13 @@ def train(model, X_train, y_train,
3115 divergence_factor = 2.0 ,
3216 score_limit = None ,
3317 verbose = 0 ,
34- batch_size = None ,
3518 ):
36- """Mini-batch training loop for logistic regression.
19+ """Full-dataset training loop for logistic regression."""
20+ if max_iterations <= 0 :
21+ raise ValueError ('max_iterations must be positive' )
22+ if check_interval <= 0 :
23+ raise ValueError ('check_interval must be positive' )
3724
38- Copies data into a reusable buffer when mini-batching to limit peak memory.
39- """
4025 n_features = model .get_n_features ()
4126 n_classes = model .get_n_classes ()
4227 if len (X_train ) % n_features != 0 :
@@ -48,53 +33,27 @@ def train(model, X_train, y_train,
4833 if n_samples == 0 :
4934 raise ValueError ('y_train is empty' )
5035
51- if batch_size is None or batch_size <= 0 or batch_size > n_samples :
52- batch_size = n_samples
53-
54- logits_buf , probs_buf , bias_buf = _make_workspace_triplet (n_classes )
55- score_logits , score_probs = _make_workspace_pair (n_classes )
56- predict_logits , predict_probs = _make_predict_buffers (n_classes )
36+ logits_buf = _make_buffer (n_classes )
37+ probs_buf = _make_buffer (n_classes )
38+ bias_buf = _make_buffer (n_classes )
39+ score_logits = _make_buffer (n_classes )
40+ score_probs = _make_buffer (n_classes )
5741
5842 prev_loss = None
5943 final_loss = float ('inf' )
6044 iterations_completed = 0
6145
62- use_batches = batch_size < n_samples
63- full_X_view = memoryview (X_train )
64- full_y_view = memoryview (y_train )
65- if use_batches :
66- batch_X = array .array ('f' , [0.0 ] * (batch_size * n_features ))
67- batch_y = array .array ('f' , [0.0 ] * (batch_size * n_classes ))
68- batch_X_view = memoryview (batch_X )
69- batch_y_view = memoryview (batch_y )
70- else :
71- batch_X_view = full_X_view
72- batch_y_view = full_y_view
46+ X_view = memoryview (X_train )
47+ y_view = memoryview (y_train )
7348
7449 for _ in range (max_iterations ):
7550 iterations_completed += 1
76- if use_batches :
77- for start in range (0 , n_samples , batch_size ):
78- count = min (batch_size , n_samples - start )
79- base_feature = start * n_features
80- base_target = start * n_classes
81- # Copy features for current batch
82- end_f = base_feature + count * n_features
83- batch_X [:count * n_features ] = X_train [base_feature :end_f ]
84- # Copy targets
85- end_t = base_target + count * n_classes
86- batch_y [:count * n_classes ] = y_train [base_target :end_t ]
87-
88- X_slice = batch_X_view [:count * n_features ]
89- y_slice = batch_y_view [:count * n_classes ]
90- model .step (X_slice , y_slice , logits_buf , probs_buf , bias_buf )
91- else :
92- model .step (batch_X_view , batch_y_view , logits_buf , probs_buf , bias_buf )
51+ model .step (X_view , y_view , logits_buf , probs_buf , bias_buf )
9352
9453 if iterations_completed % check_interval != 0 :
9554 continue
9655
97- current_loss = model .score_logloss (full_X_view , full_y_view , score_logits , score_probs )
56+ current_loss = model .score_logloss (X_view , y_view , score_logits , score_probs )
9857 final_loss = current_loss
9958 change = float ('inf' ) if prev_loss is None else abs (prev_loss - current_loss )
10059
@@ -123,7 +82,116 @@ def train(model, X_train, y_train,
12382 prev_loss = current_loss
12483
12584 if final_loss == float ('inf' ):
126- final_loss = model .score_logloss (full_X_view , full_y_view , score_logits , score_probs )
85+ final_loss = model .score_logloss (X_view , y_view , score_logits , score_probs )
86+
87+ return iterations_completed , final_loss
88+
89+
def train_batches(model,
        batch_iter_factory,
        max_iterations=200,
        tolerance=1e-4,
        check_interval=5,
        divergence_factor=2.0,
        score_limit=None,
        verbose=0,
        score_batches=None,
        ):
    """Train logistic regression model using externally provided batches.

    batch_iter_factory must be a callable that returns a fresh iterator for each
    epoch. Each iterator should yield tuples of (X_batch, y_batch) where both are
    float32 arrays compatible with model.step(). y_batch must be one-hot encoded.

    score_batches is an optional callable taking the model and returning the
    average log-loss over the data (computed however the caller prefers). When
    provided, it is used for convergence checking.

    Returns a tuple (iterations_completed, final_loss). final_loss is
    float('inf') when score_batches is not provided.

    Raises ValueError for invalid arguments, malformed batches, or an epoch
    that yields no non-empty batches.
    """
    if not callable(batch_iter_factory):
        raise ValueError('batch_iter_factory must be callable')
    if max_iterations <= 0:
        raise ValueError('max_iterations must be positive')
    if check_interval <= 0:
        raise ValueError('check_interval must be positive')
    if score_batches is not None and not callable(score_batches):
        raise ValueError('score_batches must be callable')

    n_features = model.get_n_features()
    n_classes = model.get_n_classes()
    # Guard early: a zero n_features would otherwise surface as an opaque
    # ZeroDivisionError in the batch validation below.
    if n_features <= 0 or n_classes <= 0:
        raise ValueError('model must report positive n_features and n_classes')

    # Scratch buffers reused across model.step() calls (float32, one slot
    # per class) to avoid per-batch allocation.
    logits_buf = array.array('f', [0.0] * n_classes)
    probs_buf = array.array('f', [0.0] * n_classes)
    bias_buf = array.array('f', [0.0] * n_classes)

    prev_loss = None
    final_loss = float('inf')
    iterations_completed = 0

    for _ in range(max_iterations):
        iterations_completed += 1
        batches = batch_iter_factory()
        try:
            batch_iter = iter(batches)
        except TypeError as exc:
            # Chain the original TypeError so the root cause stays visible.
            raise ValueError('batch iterator must be iterable') from exc

        batches_processed = 0

        for batch in batch_iter:
            try:
                X_batch, y_batch = batch
            except Exception as exc:
                raise ValueError('each batch must unpack into (X_batch, y_batch)') from exc

            # Skip empty batches without counting them as training work.
            if len(X_batch) == 0:
                continue
            if len(X_batch) % n_features != 0:
                raise ValueError('X_batch size mismatch with n_features')
            n_samples = len(X_batch) // n_features
            if len(y_batch) != n_samples * n_classes:
                raise ValueError('y_batch must be one-hot encoded (len = n_samples * n_classes)')

            model.step(X_batch, y_batch, logits_buf, probs_buf, bias_buf)
            # Count only batches that actually contributed a training step,
            # so an iterator yielding solely empty batches is flagged below.
            batches_processed += 1

        if batches_processed == 0:
            raise ValueError('batch iterator produced no batches')

        if iterations_completed % check_interval != 0:
            continue
        if score_batches is None:
            # No scorer: no convergence checking is possible.
            continue

        current_loss = float(score_batches(model))
        final_loss = current_loss
        change = float('inf') if prev_loss is None else abs(prev_loss - current_loss)

        if verbose >= 2:
            print(log_prefix, f'Iteration {iterations_completed} loss={current_loss}')

        # Require surviving a couple of check intervals before a small loss
        # change is accepted as convergence.
        converged = change < tolerance and iterations_completed > check_interval * 2

        if score_limit is not None and current_loss <= score_limit:
            converged = True

        # NaN detection: NaN is the only float value not equal to itself.
        diverged = not (current_loss == current_loss)
        if not diverged and prev_loss is not None:
            diverged = current_loss > prev_loss * divergence_factor

        if converged:
            if verbose >= 1:
                print(log_prefix, f"Converged at iteration {iterations_completed}")
            break

        if diverged:
            if verbose >= 1:
                print(log_prefix, f"Diverged at iteration {iterations_completed}")
            break

        prev_loss = current_loss

    # Loop may end without ever landing on a check_interval boundary;
    # compute a final score in that case so the return value is meaningful.
    if score_batches is not None and final_loss == float('inf'):
        final_loss = float(score_batches(model))

    return iterations_completed, final_loss
129197
0 commit comments