Skip to content

Commit a2173e2

Browse files
Fix reporter generation mismatch and checkpoint lost-work problem
Reporter fix (#273): Moved species detail reporting from end_generation to post_evaluate in StdOutReporter. Species info now matches the evaluated population rather than the post-reproduction next generation. Checkpoint fix (#132/#213): Moved checkpoint saving from end_generation to post_evaluate. Checkpoints now save the evaluated population with fitness values, so restoring never re-runs an expensive evaluation. Added skip-first-evaluation flag to Population so restored runs proceed directly to reproduction. Best genome is now preserved in checkpoints. Old 5-tuple checkpoint format is still loadable for backward compat. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c1ae2dd commit a2173e2

File tree

6 files changed

+323
-296
lines changed

6 files changed

+323
-296
lines changed

neat/checkpoint.py

Lines changed: 60 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,25 @@ class Checkpointer(BaseReporter):
1313
"""
1414
A reporter class that performs checkpointing using `pickle`
1515
to save and restore populations (and other aspects of the simulation state).
16+
17+
Checkpoints are saved after fitness evaluation (in ``post_evaluate``), so the
18+
saved population contains genomes with their evaluated fitness values. This
19+
means restoring a checkpoint never re-evaluates work that was already done.
20+
21+
The checkpoint filename suffix (for example, ``neat-checkpoint-10``) refers to
22+
the generation that has just been **evaluated**. Restoring checkpoint ``N``
23+
reproduces from generation ``N``'s evaluated results and then continues
24+
evaluation from generation ``N + 1``.
1625
"""
1726

1827
def __init__(self, generation_interval, time_interval_seconds=None,
1928
filename_prefix='neat-checkpoint-'):
2029
"""
21-
Saves the current state (at the end of a generation) every ``generation_interval`` generations or
22-
``time_interval_seconds``, whichever happens first.
23-
24-
The checkpoint filename suffix (for example, ``neat-checkpoint-10``) always refers to the
25-
**next generation to be evaluated**. In other words, a checkpoint created with suffix ``N``
26-
contains the population and species state for generation ``N`` at the point just before
27-
its fitness evaluation begins.
30+
Saves the current state (after fitness evaluation) every
31+
``generation_interval`` generations or ``time_interval_seconds``,
32+
whichever happens first.
2833
29-
:param generation_interval: If not None, maximum number of generations between save intervals,
30-
measured in generations-to-be-evaluated
34+
:param generation_interval: If not None, maximum number of generations between save intervals
3135
:type generation_interval: int or None
3236
:param time_interval_seconds: If not None, maximum number of seconds between checkpoint attempts
3337
:type time_interval_seconds: float or None
@@ -38,30 +42,18 @@ def __init__(self, generation_interval, time_interval_seconds=None,
3842
self.filename_prefix = filename_prefix
3943

4044
self.current_generation = None
41-
# Tracks the most recent generation index for which a checkpoint was created.
42-
# This value is interpreted as the next generation to be evaluated when the
43-
# checkpoint is restored (see above).
44-
self.last_generation_checkpoint = 0
45+
self.last_generation_checkpoint = -1
4546
self.last_time_checkpoint = time.time()
4647

4748
def start_generation(self, generation):
48-
"""Record the index of the generation that is about to be evaluated.
49-
50-
Note that at the time :meth:`end_generation` is called for generation ``g``,
51-
the population and species that are passed in already correspond to the
52-
*next* generation (``g + 1``). This reporter therefore uses ``g + 1`` as
53-
the generation index stored in checkpoints, so that restoring a
54-
checkpoint labeled ``N`` always resumes at the beginning of generation
55-
``N``.
56-
"""
5749
self.current_generation = generation
5850

59-
def end_generation(self, config, population, species_set):
60-
"""Potentially save a checkpoint at the end of a generation.
51+
def post_evaluate(self, config, population, species, best_genome):
52+
"""Potentially save a checkpoint after fitness evaluation.
6153
62-
The ``population`` and ``species_set`` arguments contain the state for
63-
the next generation to be evaluated, whose index is
64-
``self.current_generation + 1``.
54+
At this point the population has been evaluated and species membership
55+
corresponds to the evaluated genomes, so the checkpoint captures a
56+
fully consistent state with no wasted work on restore.
6557
"""
6658
checkpoint_due = False
6759

@@ -70,71 +62,84 @@ def end_generation(self, config, population, species_set):
7062
if dt >= self.time_interval_seconds:
7163
checkpoint_due = True
7264

73-
# The generation whose population is being saved.
74-
next_generation = self.current_generation + 1
75-
7665
if (not checkpoint_due) and (self.generation_interval is not None):
77-
# Compare the upcoming generation index against the last checkpointed
78-
# generation index to decide whether a new checkpoint is due.
79-
dg = next_generation - self.last_generation_checkpoint
66+
dg = self.current_generation - self.last_generation_checkpoint
8067
if dg >= self.generation_interval:
8168
checkpoint_due = True
8269

8370
if checkpoint_due:
84-
self.save_checkpoint(config, population, species_set, next_generation)
85-
self.last_generation_checkpoint = next_generation
71+
self.save_checkpoint(config, population, species,
72+
self.current_generation, best_genome)
73+
self.last_generation_checkpoint = self.current_generation
8674
self.last_time_checkpoint = time.time()
8775

88-
def save_checkpoint(self, config, population, species_set, generation):
76+
def save_checkpoint(self, config, population, species_set, generation, best_genome=None):
8977
"""
9078
Save the current simulation state.
91-
92-
Note: This is called from Population via the reporter interface.
93-
We need to access the innovation tracker from the Population's reproduction object.
94-
However, since this is a reporter callback, we don't have direct access to Population.
95-
The innovation tracker will be saved as part of the config state when needed.
79+
80+
The saved data includes the evaluated population (with fitness values),
81+
the species set, the generation index, the all-time best genome, and the
82+
random state for reproducibility.
9683
"""
9784
filename = f'{self.filename_prefix}{generation}'
9885
print(f"Saving checkpoint to {filename}")
9986

10087
with gzip.open(filename, 'w', compresslevel=5) as f:
101-
# Note: innovation_tracker is stored in config.genome_config.innovation_tracker
102-
# and is automatically included via pickle
103-
data = (generation, config, population, species_set, random.getstate())
88+
data = (generation, config, population, species_set,
89+
random.getstate(), best_genome)
10490
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
10591

10692
@staticmethod
10793
def restore_checkpoint(filename, new_config=None):
10894
"""
10995
Resumes the simulation from a previous saved point.
110-
111-
The innovation tracker state is preserved in the pickled config and must be
112-
transferred to the new reproduction object to ensure innovation numbers continue
113-
correctly and prevent collisions during crossover.
96+
97+
The checkpoint contains the evaluated population from generation ``N``.
98+
On restore, evaluation is skipped for this generation and the evolution
99+
loop proceeds directly to reproduction, continuing with generation
100+
``N + 1``.
101+
102+
The innovation tracker state is preserved in the pickled config and
103+
transferred to the new reproduction object to ensure innovation numbers
104+
continue correctly.
114105
"""
115106
with gzip.open(filename) as f:
116-
generation, saved_config, population, species_set, rndstate = pickle.load(f)
107+
data = pickle.load(f)
108+
# Support both old (5-tuple) and new (6-tuple) checkpoint formats.
109+
if len(data) == 6:
110+
generation, saved_config, population, species_set, rndstate, best_genome = data
111+
else:
112+
generation, saved_config, population, species_set, rndstate = data
113+
best_genome = None
114+
117115
random.setstate(rndstate)
118-
116+
119117
# Extract the saved innovation tracker from the config before replacing it
120118
saved_innovation_tracker = None
121119
if hasattr(saved_config.genome_config, 'innovation_tracker'):
122120
saved_innovation_tracker = saved_config.genome_config.innovation_tracker
123-
121+
124122
# Use new config if provided, otherwise use saved config
125123
if new_config is not None:
126124
config = new_config
127125
else:
128126
config = saved_config
129-
127+
130128
# Create Population with restored state
131-
# This creates a new reproduction object with a fresh innovation tracker
132129
restored_pop = Population(config, (population, species_set, generation))
133-
130+
131+
# Restore best_genome so the all-time best is not lost
132+
if best_genome is not None:
133+
restored_pop.best_genome = best_genome
134+
135+
# Tell run() to skip the first evaluation — it was already done
136+
# before this checkpoint was saved.
137+
restored_pop._skip_first_evaluation = True
138+
134139
# Replace the fresh innovation tracker with the saved one to maintain
135140
# the correct innovation numbering sequence
136141
if saved_innovation_tracker is not None:
137142
restored_pop.reproduction.innovation_tracker = saved_innovation_tracker
138143
config.genome_config.innovation_tracker = saved_innovation_tracker
139-
144+
140145
return restored_pop

neat/population.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(self, config, initial_state=None, seed=None):
6464
self.reproduction.genome_indexer = count(max(self.population.keys()) + 1)
6565

6666
self.best_genome = None
67+
self._skip_first_evaluation = False
6768

6869
def add_reporter(self, reporter):
6970
self.reporters.add(reporter)
@@ -100,29 +101,35 @@ def run(self, fitness_function, n=None):
100101

101102
self.reporters.start_generation(self.generation)
102103

103-
# Evaluate all genomes using the user-provided function.
104-
fitness_function(list(self.population.items()), self.config)
105-
106-
# Gather and report statistics.
107-
best = None
108-
for g in self.population.values():
109-
if g.fitness is None:
110-
raise RuntimeError(f"Fitness not assigned to genome {g.key}")
111-
112-
if best is None or self.config.is_better_fitness(g.fitness, best.fitness):
113-
best = g
114-
self.reporters.post_evaluate(self.config, self.population, self.species, best)
115-
116-
# Track the best genome ever seen.
117-
if self.best_genome is None or self.config.is_better_fitness(best.fitness, self.best_genome.fitness):
118-
self.best_genome = best
119-
120-
if not self.config.no_fitness_termination:
121-
# End if the fitness threshold is reached.
122-
fv = self.fitness_criterion(g.fitness for g in self.population.values())
123-
if self.config.meets_threshold(fv, self.config.fitness_threshold):
124-
self.reporters.found_solution(self.config, self.generation, best)
125-
break
104+
if self._skip_first_evaluation:
105+
# Restored from a checkpoint saved after evaluation.
106+
# The population already has fitness values and reporters
107+
# already saw these results, so skip straight to reproduction.
108+
self._skip_first_evaluation = False
109+
else:
110+
# Evaluate all genomes using the user-provided function.
111+
fitness_function(list(self.population.items()), self.config)
112+
113+
# Gather and report statistics.
114+
best = None
115+
for g in self.population.values():
116+
if g.fitness is None:
117+
raise RuntimeError(f"Fitness not assigned to genome {g.key}")
118+
119+
if best is None or self.config.is_better_fitness(g.fitness, best.fitness):
120+
best = g
121+
self.reporters.post_evaluate(self.config, self.population, self.species, best)
122+
123+
# Track the best genome ever seen.
124+
if self.best_genome is None or self.config.is_better_fitness(best.fitness, self.best_genome.fitness):
125+
self.best_genome = best
126+
127+
if not self.config.no_fitness_termination:
128+
# End if the fitness threshold is reached.
129+
fv = self.fitness_criterion(g.fitness for g in self.population.values())
130+
if self.config.meets_threshold(fv, self.config.fitness_threshold):
131+
self.reporters.found_solution(self.config, self.generation, best)
132+
break
126133

127134
# Create the next generation from the current generation.
128135
self.population = self.reproduction.reproduce(self.config, self.species,

neat/reporting.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -100,34 +100,34 @@ def start_generation(self, generation):
100100
self.generation_start_time = time.time()
101101

102102
def end_generation(self, config, population, species_set):
103+
elapsed = time.time() - self.generation_start_time
104+
self.generation_times.append(elapsed)
105+
self.generation_times = self.generation_times[-10:]
106+
average = sum(self.generation_times) / len(self.generation_times)
107+
print(f'Total extinctions: {self.num_extinctions:d}')
108+
if len(self.generation_times) > 1:
109+
print(f"Generation time: {elapsed:.3f} sec ({average:.3f} average)")
110+
else:
111+
print(f"Generation time: {elapsed:.3f} sec")
112+
113+
def post_evaluate(self, config, population, species, best_genome):
103114
ng = len(population)
104-
ns = len(species_set.species)
115+
ns = len(species.species)
105116
if self.show_species_detail:
106-
print(f'Population of {ng:d} members in {ns:d} species (after reproduction):')
117+
print(f'Population of {ng:d} members in {ns:d} species:')
107118
print(" ID age size fitness adj fit stag")
108119
print(" ==== === ==== ========= ======= ====")
109-
for sid in sorted(species_set.species):
110-
s = species_set.species[sid]
120+
for sid in sorted(species.species):
121+
s = species.species[sid]
111122
a = self.generation - s.created
112123
n = len(s.members)
113124
f = "--" if s.fitness is None else f"{s.fitness:.3f}"
114125
af = "--" if s.adjusted_fitness is None else f"{s.adjusted_fitness:.3f}"
115126
st = self.generation - s.last_improved
116127
print(f" {sid:>4} {a:>3} {n:>4} {f:>9} {af:>7} {st:>4}")
117128
else:
118-
print(f'Population of {ng:d} members in {ns:d} species (after reproduction)')
129+
print(f'Population of {ng:d} members in {ns:d} species')
119130

120-
elapsed = time.time() - self.generation_start_time
121-
self.generation_times.append(elapsed)
122-
self.generation_times = self.generation_times[-10:]
123-
average = sum(self.generation_times) / len(self.generation_times)
124-
print(f'Total extinctions: {self.num_extinctions:d}')
125-
if len(self.generation_times) > 1:
126-
print(f"Generation time: {elapsed:.3f} sec ({average:.3f} average)")
127-
else:
128-
print(f"Generation time: {elapsed:.3f} sec")
129-
130-
def post_evaluate(self, config, population, species, best_genome):
131131
fitnesses = [c.fitness for c in population.values()]
132132
fit_mean = mean(fitnesses)
133133
fit_std = stdev(fitnesses)

0 commit comments

Comments
 (0)