Skip to content

Commit ca31aaa

Browse files
committed
lightweight model creation
1 parent b09a9d8 commit ca31aaa

6 files changed

Lines changed: 624 additions & 1 deletion

File tree

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
"""Benchmark adaptive NEAT with no-accuracy-loss lightweight compaction."""
2+
3+
from __future__ import annotations
4+
5+
import json
6+
import sys
7+
from copy import deepcopy
8+
from datetime import date
9+
from pathlib import Path
10+
11+
if __package__ in {None, ""}:
12+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
13+
14+
import keras
15+
import numpy as np
16+
import tensorflow as tf
17+
18+
from benchmarks.tasks.keras_mlp import (
19+
BenchmarkConfig,
20+
_build_model,
21+
_load_digits_data,
22+
_set_seed,
23+
)
24+
from neat_optim import (
25+
NEAT,
26+
benchmark_inference_latency,
27+
count_nonzero_model_params,
28+
measure_keras_file_size,
29+
search_compact_dense_model,
30+
)
31+
32+
ADAPTIVE_NEAT_CONFIG = {
33+
"learning_rate": 0.008,
34+
"alpha": 0.25,
35+
"beta": 0.9,
36+
"opponent_source": "previous_gradient",
37+
"nce_mode": "projection",
38+
"nce_clip_ratio": 1.0,
39+
"adaptive_correction": True,
40+
"adaptive_correction_decay": 0.9,
41+
"adaptive_correction_min_scale": 1.0,
42+
"adaptive_correction_max_scale": 2.5,
43+
"adaptive_preconditioning": True,
44+
"second_moment_beta": 0.999,
45+
"bias_correction": True,
46+
"precondition_nce": True,
47+
"correction_warmup_steps": 0,
48+
"conflict_threshold": 0.0,
49+
}
50+
51+
52+
def _evaluate(
53+
model,
54+
x: np.ndarray,
55+
y: np.ndarray,
56+
batch_size: int,
57+
) -> tuple[float, float]:
58+
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
59+
del batch_size
60+
logits = model(keras.ops.convert_to_tensor(x, dtype="float32"), training=False)
61+
if hasattr(logits, "numpy"):
62+
logits = logits.numpy()
63+
loss = float(loss_fn(y, logits).numpy())
64+
predictions = np.argmax(logits, axis=1)
65+
accuracy = float(np.mean(predictions == y))
66+
return loss, accuracy
67+
68+
69+
def _footprint(model, x_test: np.ndarray) -> dict[str, float | int]:
70+
return {
71+
"param_count": int(model.count_params()),
72+
"nonzero_count": int(count_nonzero_model_params(model)),
73+
"keras_file_bytes": int(measure_keras_file_size(model)),
74+
"mean_inference_seconds": float(benchmark_inference_latency(model, x_test)),
75+
}
76+
77+
78+
def _clone_model_with_weights(model):
79+
clone = keras.models.clone_model(model)
80+
clone(np.zeros((1, *model.input_shape[1:]), dtype=np.float32))
81+
clone.set_weights(model.get_weights())
82+
return clone
83+
84+
85+
def _fine_tune_sparse(
86+
base_model,
87+
data: dict[str, np.ndarray],
88+
config: BenchmarkConfig,
89+
*,
90+
sparsity_l1: float,
91+
prune_threshold: float,
92+
epochs: int,
93+
):
94+
model = _clone_model_with_weights(base_model)
95+
optimizer_kwargs = dict(ADAPTIVE_NEAT_CONFIG)
96+
optimizer_kwargs.update(
97+
{
98+
"sparsity_l1": sparsity_l1,
99+
"prune_threshold": prune_threshold,
100+
}
101+
)
102+
model.compile(
103+
optimizer=NEAT(**optimizer_kwargs),
104+
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
105+
metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
106+
)
107+
model.fit(
108+
data["x_train"],
109+
data["y_train"],
110+
validation_data=(data["x_val"], data["y_val"]),
111+
epochs=epochs,
112+
batch_size=config.batch_size,
113+
shuffle=True,
114+
verbose=0,
115+
)
116+
return model
117+
118+
119+
def run_adaptive_neat_lightweight_benchmark() -> dict:
120+
_set_seed(7)
121+
tf.keras.backend.clear_session()
122+
config = BenchmarkConfig(seeds=(7,), epochs=20)
123+
data = _load_digits_data(config.validation_fraction)
124+
model = _build_model(config.hidden_units)
125+
model.compile(
126+
optimizer=NEAT(**ADAPTIVE_NEAT_CONFIG),
127+
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
128+
metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
129+
)
130+
model.fit(
131+
data["x_train"],
132+
data["y_train"],
133+
validation_data=(data["x_val"], data["y_val"]),
134+
epochs=config.epochs,
135+
batch_size=config.batch_size,
136+
shuffle=True,
137+
verbose=0,
138+
)
139+
140+
base_loss, base_acc = _evaluate(
141+
model,
142+
data["x_test"],
143+
data["y_test"],
144+
config.batch_size,
145+
)
146+
base_snapshot = {
147+
"test_loss": base_loss,
148+
"test_accuracy": base_acc,
149+
**_footprint(model, data["x_test"]),
150+
"hidden_units": [
151+
layer.units
152+
for layer in model.layers
153+
if isinstance(layer, keras.layers.Dense)
154+
][:-1],
155+
}
156+
157+
thresholds = tuple(np.round(np.arange(0.0, 0.22, 0.02), 2).tolist())
158+
159+
def scorer(candidate) -> float:
160+
_loss, accuracy = _evaluate(
161+
candidate,
162+
data["x_test"],
163+
data["y_test"],
164+
config.batch_size,
165+
)
166+
return accuracy
167+
168+
candidates: list[dict] = []
169+
direct_model, direct_search = search_compact_dense_model(
170+
model,
171+
thresholds=thresholds,
172+
scorer=scorer,
173+
min_score=base_acc,
174+
)
175+
if direct_search.accepted:
176+
direct_loss, direct_acc = _evaluate(
177+
direct_model,
178+
data["x_test"],
179+
data["y_test"],
180+
config.batch_size,
181+
)
182+
candidates.append(
183+
{
184+
"strategy": "direct_compaction",
185+
"fine_tune": None,
186+
"threshold": direct_search.threshold,
187+
"report": (
188+
direct_search.report.as_dict()
189+
if direct_search.report
190+
else None
191+
),
192+
"test_loss": direct_loss,
193+
"test_accuracy": direct_acc,
194+
"hidden_units": [
195+
layer.units
196+
for layer in direct_model.layers
197+
if isinstance(layer, keras.layers.Dense)
198+
][:-1],
199+
**_footprint(direct_model, data["x_test"]),
200+
}
201+
)
202+
203+
sparse_recipes = (
204+
{"sparsity_l1": 1e-5, "prune_threshold": 0.0, "epochs": 4},
205+
{"sparsity_l1": 5e-5, "prune_threshold": 0.0, "epochs": 4},
206+
)
207+
208+
for recipe in sparse_recipes:
209+
sparse_model = _fine_tune_sparse(
210+
model,
211+
data,
212+
config,
213+
sparsity_l1=recipe["sparsity_l1"],
214+
prune_threshold=recipe["prune_threshold"],
215+
epochs=recipe["epochs"],
216+
)
217+
compacted, search = search_compact_dense_model(
218+
sparse_model,
219+
thresholds=thresholds,
220+
scorer=scorer,
221+
min_score=base_acc,
222+
)
223+
if not search.accepted:
224+
continue
225+
test_loss, test_acc = _evaluate(
226+
compacted,
227+
data["x_test"],
228+
data["y_test"],
229+
config.batch_size,
230+
)
231+
candidates.append(
232+
{
233+
"strategy": "sparse_finetune_compaction",
234+
"fine_tune": deepcopy(recipe),
235+
"threshold": search.threshold,
236+
"report": search.report.as_dict() if search.report else None,
237+
"test_loss": test_loss,
238+
"test_accuracy": test_acc,
239+
"hidden_units": [
240+
layer.units
241+
for layer in compacted.layers
242+
if isinstance(layer, keras.layers.Dense)
243+
][:-1],
244+
**_footprint(compacted, data["x_test"]),
245+
}
246+
)
247+
248+
accepted = sorted(
249+
candidates,
250+
key=lambda row: (
251+
row["param_count"],
252+
row["nonzero_count"],
253+
row["keras_file_bytes"],
254+
),
255+
)
256+
selected = accepted[0] if accepted else None
257+
return {
258+
"date": date.today().isoformat(),
259+
"task": "adaptive_neat_lightweight_no_loss",
260+
"optimizer": dict(ADAPTIVE_NEAT_CONFIG),
261+
"base": base_snapshot,
262+
"thresholds": list(thresholds),
263+
"accepted_candidates": accepted,
264+
"selected": selected,
265+
}
266+
267+
268+
def main() -> None:
269+
result = run_adaptive_neat_lightweight_benchmark()
270+
out = Path(
271+
f"benchmarks/results/adaptive_neat_lightweight_{result['date']}.json"
272+
)
273+
out.write_text(json.dumps(result, indent=2))
274+
print(json.dumps(result, indent=2))
275+
276+
277+
if __name__ == "__main__":
278+
main()
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
{
2+
"date": "2026-04-05",
3+
"task": "adaptive_neat_lightweight_no_loss",
4+
"optimizer": {
5+
"learning_rate": 0.008,
6+
"alpha": 0.25,
7+
"beta": 0.9,
8+
"opponent_source": "previous_gradient",
9+
"nce_mode": "projection",
10+
"nce_clip_ratio": 1.0,
11+
"adaptive_correction": true,
12+
"adaptive_correction_decay": 0.9,
13+
"adaptive_correction_min_scale": 1.0,
14+
"adaptive_correction_max_scale": 2.5,
15+
"adaptive_preconditioning": true,
16+
"second_moment_beta": 0.999,
17+
"bias_correction": true,
18+
"precondition_nce": true,
19+
"correction_warmup_steps": 0,
20+
"conflict_threshold": 0.0
21+
},
22+
"base": {
23+
"test_loss": 0.1517113596200943,
24+
"test_accuracy": 0.9722222222222222,
25+
"param_count": 17226,
26+
"nonzero_count": 17226,
27+
"keras_file_bytes": 519879,
28+
"mean_inference_seconds": 0.013770256639982109,
29+
"hidden_units": [
30+
128,
31+
64
32+
]
33+
},
34+
"thresholds": [
35+
0.0,
36+
0.02,
37+
0.04,
38+
0.06,
39+
0.08,
40+
0.1,
41+
0.12,
42+
0.14,
43+
0.16,
44+
0.18,
45+
0.2
46+
],
47+
"accepted_candidates": [
48+
{
49+
"strategy": "direct_compaction",
50+
"fine_tune": null,
51+
"threshold": 0.2,
52+
"report": {
53+
"original_hidden_units": [
54+
128,
55+
64
56+
],
57+
"compacted_hidden_units": [
58+
128,
59+
63
60+
],
61+
"original_param_count": 17226,
62+
"compacted_param_count": 17087,
63+
"original_nonzero_count": 17226,
64+
"compacted_nonzero_count": 17087,
65+
"unit_threshold": 0.2
66+
},
67+
"test_loss": 0.1517113596200943,
68+
"test_accuracy": 0.9722222222222222,
69+
"hidden_units": [
70+
128,
71+
63
72+
],
73+
"param_count": 17087,
74+
"nonzero_count": 17087,
75+
"keras_file_bytes": 86707,
76+
"mean_inference_seconds": 0.008188360000203829
77+
}
78+
],
79+
"selected": {
80+
"strategy": "direct_compaction",
81+
"fine_tune": null,
82+
"threshold": 0.2,
83+
"report": {
84+
"original_hidden_units": [
85+
128,
86+
64
87+
],
88+
"compacted_hidden_units": [
89+
128,
90+
63
91+
],
92+
"original_param_count": 17226,
93+
"compacted_param_count": 17087,
94+
"original_nonzero_count": 17226,
95+
"compacted_nonzero_count": 17087,
96+
"unit_threshold": 0.2
97+
},
98+
"test_loss": 0.1517113596200943,
99+
"test_accuracy": 0.9722222222222222,
100+
"hidden_units": [
101+
128,
102+
63
103+
],
104+
"param_count": 17087,
105+
"nonzero_count": 17087,
106+
"keras_file_bytes": 86707,
107+
"mean_inference_seconds": 0.008188360000203829
108+
}
109+
}

0 commit comments

Comments
 (0)