Skip to content

Commit 4957a9d

Browse files
Fix mutate_add_node innovation deduplication (issue #291)
The dedup key included the newly-allocated node ID, which is always unique, so the tracker lookup never matched across genomes. Key by the original connection endpoints instead, via a new get_node_split() method that also shares the node ID across genomes splitting the same connection in one generation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent: 6811a56 · commit: 4957a9d

File tree

4 files changed: +148 / -35 lines changed

4 files changed: +148 / -35 lines changed

neat/genome.py

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -403,7 +403,24 @@ def mutate_add_node(self, config):
403403

404404
# Choose a random connection to split
405405
conn_to_split = choice(list(self.connections.values()))
406-
new_node_id = config.get_new_node_key(self.nodes)
406+
407+
i, o = conn_to_split.key
408+
409+
# Get the node ID and innovation numbers for this split from the tracker.
410+
# If another genome already split this same connection in this generation,
411+
# we get back the same node ID and innovation numbers — this is a core NEAT
412+
# requirement for proper crossover alignment (Stanley & Miikkulainen, 2002).
413+
new_node_id, in_innovation, out_innovation = (
414+
config.innovation_tracker.get_node_split(
415+
i, o, lambda: config.get_new_node_key(self.nodes)
416+
)
417+
)
418+
419+
# If this genome already has this node (e.g. from crossover with a genome
420+
# that previously split this same connection), skip the mutation.
421+
if new_node_id in self.nodes:
422+
return
423+
407424
ng = self.create_node(config, new_node_id)
408425

409426
# Make the new node as neutral as possible with respect to the
@@ -419,18 +436,6 @@ def mutate_add_node(self, config):
419436
# the original connection (depending on the activation function of the new node).
420437
conn_to_split.enabled = False
421438

422-
i, o = conn_to_split.key
423-
424-
# Get innovation numbers for the two new connections
425-
# These are keyed by the connection being split, so multiple genomes splitting
426-
# the same connection get matching innovation numbers
427-
in_innovation = config.innovation_tracker.get_innovation_number(
428-
i, new_node_id, 'add_node_in'
429-
)
430-
out_innovation = config.innovation_tracker.get_innovation_number(
431-
new_node_id, o, 'add_node_out'
432-
)
433-
434439
# Add the two new connections with their innovation numbers
435440
self.add_connection(config, i, new_node_id, 1.0, True, innovation=in_innovation)
436441
self.add_connection(config, new_node_id, o, conn_to_split.weight, True, innovation=out_innovation)

neat/innovation.py

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,44 @@ def get_innovation_number(self, input_node, output_node, mutation_type='add_conn
9696

9797
return innovation_number
9898

99+
def get_node_split(self, in_node, out_node, allocate_node_key):
100+
"""
101+
Get or assign innovation numbers and a node ID for splitting a connection.
102+
103+
If this same connection (in_node -> out_node) has already been split in the
104+
current generation, returns the previously assigned (node_id, in_innovation,
105+
out_innovation). Otherwise, calls allocate_node_key() to obtain a new node ID
106+
and assigns two new innovation numbers.
107+
108+
This ensures that when multiple genomes independently split the same connection
109+
in one generation, they all receive the same node ID and matching innovation
110+
numbers — a core requirement of the NEAT algorithm for proper crossover alignment.
111+
112+
Args:
113+
in_node: The input node of the connection being split
114+
out_node: The output node of the connection being split
115+
allocate_node_key: A callable that returns a new unique node ID.
116+
Only called the first time this split is seen in
117+
the current generation.
118+
119+
Returns:
120+
tuple: (node_id, in_innovation, out_innovation)
121+
"""
122+
key = (in_node, out_node, 'split_node')
123+
124+
if key in self.generation_innovations:
125+
return self.generation_innovations[key]
126+
127+
node_id = allocate_node_key()
128+
self.global_counter += 1
129+
in_innovation = self.global_counter
130+
self.global_counter += 1
131+
out_innovation = self.global_counter
132+
result = (node_id, in_innovation, out_innovation)
133+
self.generation_innovations[key] = result
134+
135+
return result
136+
99137
def reset_generation(self):
100138
"""
101139
Clear generation-specific tracking at the start of a new generation.

tests/test_genome_mutations.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -402,20 +402,23 @@ def test_add_connection_with_innovation_tracking(self):
402402
def test_remove_node_deletes_node(self):
403403
"""
404404
Test that remove_node mutation deletes a node.
405-
405+
406406
Should remove exactly one node from the genome.
407407
"""
408408
genome = self.create_minimal_genome()
409-
410-
# Add some nodes first
409+
410+
# Add nodes, also adding connections between splits so that
411+
# each split targets a distinct connection (avoiding the
412+
# within-generation dedup in the innovation tracker).
411413
for _ in range(3):
412414
if genome.connections:
413415
genome.mutate_add_node(self.config.genome_config)
414-
416+
genome.mutate_add_connection(self.config.genome_config)
417+
415418
initial_count = len(genome.nodes)
416419
if initial_count > len(self.config.genome_config.output_keys):
417420
genome.mutate_delete_node(self.config.genome_config)
418-
421+
419422
self.assertEqual(len(genome.nodes), initial_count - 1,
420423
"Should remove one node")
421424

tests/test_innovation.py

Lines changed: 84 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -100,24 +100,91 @@ def test_reset_generation_preserves_global_counter(self):
100100
inn = tracker.get_innovation_number(2, 3, 'add_connection')
101101
self.assertEqual(inn, 4)
102102

103-
def test_add_node_mutations_tracked_separately(self):
104-
"""Test that add_node mutations track both connections separately."""
103+
def test_node_split_deduplication(self):
104+
"""Test that splitting the same connection in one generation produces
105+
the same node ID and innovation numbers, regardless of which genome
106+
performs the split."""
105107
tracker = neat.InnovationTracker()
106-
107-
# Split connection 0->1 by adding node 2
108-
inn_in = tracker.get_innovation_number(0, 2, 'add_node_in')
109-
inn_out = tracker.get_innovation_number(2, 1, 'add_node_out')
110-
111-
self.assertEqual(inn_in, 1)
112-
self.assertEqual(inn_out, 2)
113-
114-
# Another genome splits same connection
115-
inn_in2 = tracker.get_innovation_number(0, 2, 'add_node_in')
116-
inn_out2 = tracker.get_innovation_number(2, 1, 'add_node_out')
117-
118-
# Should get same innovation numbers
119-
self.assertEqual(inn_in2, inn_in)
120-
self.assertEqual(inn_out2, inn_out)
108+
next_node = iter(range(10, 100))
109+
110+
# Genome A splits connection 0->1
111+
node_a, inn_in_a, inn_out_a = tracker.get_node_split(
112+
0, 1, lambda: next(next_node)
113+
)
114+
print(f" Genome A: node={node_a}, in_inn={inn_in_a}, out_inn={inn_out_a}")
115+
116+
self.assertEqual(node_a, 10) # first call allocates from iterator
117+
self.assertEqual(inn_in_a, 1)
118+
self.assertEqual(inn_out_a, 2)
119+
120+
# Genome B splits the SAME connection 0->1 in the same generation
121+
node_b, inn_in_b, inn_out_b = tracker.get_node_split(
122+
0, 1, lambda: next(next_node)
123+
)
124+
print(f" Genome B: node={node_b}, in_inn={inn_in_b}, out_inn={inn_out_b}")
125+
126+
# Must get the same node ID and innovation numbers
127+
self.assertEqual(node_b, node_a)
128+
self.assertEqual(inn_in_b, inn_in_a)
129+
self.assertEqual(inn_out_b, inn_out_a)
130+
131+
def test_node_split_different_connections_get_different_ids(self):
132+
"""Test that splitting different connections produces different node IDs
133+
and innovation numbers."""
134+
tracker = neat.InnovationTracker()
135+
next_node = iter(range(10, 100))
136+
137+
node_a, inn_in_a, inn_out_a = tracker.get_node_split(
138+
0, 1, lambda: next(next_node)
139+
)
140+
node_b, inn_in_b, inn_out_b = tracker.get_node_split(
141+
2, 3, lambda: next(next_node)
142+
)
143+
print(f" Split (0->1): node={node_a}, innovations=({inn_in_a}, {inn_out_a})")
144+
print(f" Split (2->3): node={node_b}, innovations=({inn_in_b}, {inn_out_b})")
145+
146+
self.assertNotEqual(node_a, node_b)
147+
self.assertNotEqual(inn_in_a, inn_in_b)
148+
self.assertNotEqual(inn_out_a, inn_out_b)
149+
150+
def test_node_split_allocator_called_only_once(self):
151+
"""Test that the allocate_node_key callable is only called the first time
152+
a connection is split in a generation (not on dedup hits)."""
153+
tracker = neat.InnovationTracker()
154+
call_count = [0]
155+
156+
def allocator():
157+
call_count[0] += 1
158+
return 42
159+
160+
tracker.get_node_split(0, 1, allocator)
161+
self.assertEqual(call_count[0], 1)
162+
163+
# Second call for same split should NOT call allocator
164+
tracker.get_node_split(0, 1, allocator)
165+
self.assertEqual(call_count[0], 1)
166+
167+
def test_node_split_reset_between_generations(self):
168+
"""Test that the same connection split in different generations gets
169+
different node IDs and innovation numbers."""
170+
tracker = neat.InnovationTracker()
171+
next_node = iter(range(10, 100))
172+
173+
node_g1, inn_in_g1, inn_out_g1 = tracker.get_node_split(
174+
0, 1, lambda: next(next_node)
175+
)
176+
177+
tracker.reset_generation()
178+
179+
node_g2, inn_in_g2, inn_out_g2 = tracker.get_node_split(
180+
0, 1, lambda: next(next_node)
181+
)
182+
print(f" Gen 1: node={node_g1}, innovations=({inn_in_g1}, {inn_out_g1})")
183+
print(f" Gen 2: node={node_g2}, innovations=({inn_in_g2}, {inn_out_g2})")
184+
185+
self.assertNotEqual(node_g1, node_g2)
186+
self.assertNotEqual(inn_in_g1, inn_in_g2)
187+
self.assertNotEqual(inn_out_g1, inn_out_g2)
121188

122189
def test_pickle_serialization(self):
123190
"""Test that InnovationTracker can be pickled and unpickled."""

0 commit comments

Comments (0)