Skip to content

Commit 41d25c5

Browse files
Clean up GPU backend: remove dead matmul, redundant branches, and duplicate logic
- Remove wasted cp.matmul call in CTRNN loop that computed and discarded its result; use pre-allocated s_buf via out= parameter instead
- Remove per_genome_inputs if/else branches where both paths were identical (CuPy broadcasting handles both input shapes)
- Collapse duplicate input-building if/else in both GPU evaluator classes using np.zeros((num_steps, *first_input.shape))
- Eliminate N = len(genomes) that could diverge from input_fn's actual shape

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 56653d4 commit 41d25c5

File tree

2 files changed

+21
-59
lines changed

2 files changed

+21
-59
lines changed

neat/gpu/_cupy_backend.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,8 @@ def evaluate_ctrnn_batch(packed, inputs_cpu, dt):
131131
scale[:, :num_inputs] = 0.0
132132

133133
# Transfer inputs to GPU.
134-
if inputs_cpu.ndim == 2:
135-
# [num_steps, num_inputs] → broadcast to all genomes
136-
inputs_gpu = cp.asarray(inputs_cpu.astype(np.float32))
137-
per_genome_inputs = False
138-
else:
139-
# [num_steps, N, num_inputs]
140-
inputs_gpu = cp.asarray(inputs_cpu.astype(np.float32))
141-
per_genome_inputs = True
134+
# Shape is either [num_steps, num_inputs] (broadcast) or [num_steps, N, num_inputs].
135+
inputs_gpu = cp.asarray(inputs_cpu.astype(np.float32))
142136

143137
# Initialize state.
144138
u = cp.zeros((N, M), dtype=cp.float32)
@@ -165,15 +159,11 @@ def evaluate_ctrnn_batch(packed, inputs_cpu, dt):
165159

166160
for step in range(num_steps):
167161
# Step 1: Set input node states.
168-
if per_genome_inputs:
169-
u[:, :num_inputs] = inputs_gpu[step] # [N, num_inputs]
170-
else:
171-
u[:, :num_inputs] = inputs_gpu[step] # [num_inputs] broadcast
162+
u[:, :num_inputs] = inputs_gpu[step]
172163

173164
# Step 2: Batched matrix-vector multiply.
174165
# s = W @ u → [N, M] (treat u as [N, M, 1], squeeze result)
175-
cp.matmul(W, u[:, :, None], out=s_buf[:, :, None] if False else None)
176-
s_buf = cp.matmul(W, u[:, :, None]).squeeze(-1) # [N, M]
166+
cp.matmul(W, u[:, :, None], out=s_buf[:, :, None])
177167

178168
# Step 3: Apply activation function via custom kernel.
179169
s_flat = s_buf.ravel()
@@ -187,10 +177,7 @@ def evaluate_ctrnn_batch(packed, inputs_cpu, dt):
187177
u += scale * z_buf
188178

189179
# Step 5: Re-clamp input nodes.
190-
if per_genome_inputs:
191-
u[:, :num_inputs] = inputs_gpu[step]
192-
else:
193-
u[:, :num_inputs] = inputs_gpu[step]
180+
u[:, :num_inputs] = inputs_gpu[step]
194181

195182
# Step 6: Record output node states.
196183
trajectory[:, step, :] = u[:, out_start:out_end]
@@ -235,13 +222,9 @@ def evaluate_iznn_batch(packed, inputs_cpu, dt, num_steps):
235222
d = cp.asarray(packed['d']) # [N, M]
236223
node_mask = cp.asarray(packed['node_mask']) # [N, M]
237224

238-
# Transfer inputs.
239-
if inputs_cpu.ndim == 2:
240-
inputs_gpu = cp.asarray(inputs_cpu.astype(np.float32))
241-
per_genome_inputs = False
242-
else:
243-
inputs_gpu = cp.asarray(inputs_cpu.astype(np.float32))
244-
per_genome_inputs = True
225+
# Transfer inputs to GPU.
226+
# Shape is either [num_steps, num_inputs] (broadcast) or [num_steps, N, num_inputs].
227+
inputs_gpu = cp.asarray(inputs_cpu.astype(np.float32))
245228

246229
# Initialize state: v = c, u_recov = b * v, fired = 0.
247230
v = c.copy()
@@ -263,10 +246,7 @@ def evaluate_iznn_batch(packed, inputs_cpu, dt, num_steps):
263246
for step in range(num_steps):
264247
# Build source vector: fired for neuron slots, external input for input slots.
265248
source[:] = fired
266-
if per_genome_inputs:
267-
source[:, :num_inputs] = inputs_gpu[step]
268-
else:
269-
source[:, :num_inputs] = inputs_gpu[step]
249+
source[:, :num_inputs] = inputs_gpu[step]
270250

271251
# Compute synaptic current: I = bias + W @ source
272252
I = bias + cp.matmul(W, source[:, :, None]).squeeze(-1)

neat/gpu/evaluator.py

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -67,23 +67,13 @@ def evaluate(self, genomes, config):
6767
num_steps = int(self.t_max / self.dt)
6868

6969
# Precompute input trajectory on CPU.
70+
# first_input shape is either [num_inputs] or [N, num_inputs].
7071
first_input = np.asarray(self.input_fn(0.0, self.dt), dtype=np.float32)
71-
if first_input.ndim == 1:
72-
# Scalar input: [num_inputs]
73-
inputs = np.zeros((num_steps, len(first_input)), dtype=np.float32)
74-
inputs[0] = first_input
75-
for step in range(1, num_steps):
76-
inputs[step] = np.asarray(
77-
self.input_fn(step * self.dt, self.dt), dtype=np.float32)
78-
else:
79-
# Per-genome input: [N, num_inputs]
80-
N = len(genomes)
81-
num_inputs = first_input.shape[-1]
82-
inputs = np.zeros((num_steps, N, num_inputs), dtype=np.float32)
83-
inputs[0] = first_input
84-
for step in range(1, num_steps):
85-
inputs[step] = np.asarray(
86-
self.input_fn(step * self.dt, self.dt), dtype=np.float32)
72+
inputs = np.zeros((num_steps, *first_input.shape), dtype=np.float32)
73+
inputs[0] = first_input
74+
for step in range(1, num_steps):
75+
inputs[step] = np.asarray(
76+
self.input_fn(step * self.dt, self.dt), dtype=np.float32)
8777

8878
# Pack genomes into padded arrays.
8979
packed = pack_ctrnn_population(genomes, config)
@@ -131,21 +121,13 @@ def evaluate(self, genomes, config):
131121
num_steps = int(self.t_max / self.dt)
132122

133123
# Precompute input trajectory on CPU.
124+
# first_input shape is either [num_inputs] or [N, num_inputs].
134125
first_input = np.asarray(self.input_fn(0.0, self.dt), dtype=np.float32)
135-
if first_input.ndim == 1:
136-
inputs = np.zeros((num_steps, len(first_input)), dtype=np.float32)
137-
inputs[0] = first_input
138-
for step in range(1, num_steps):
139-
inputs[step] = np.asarray(
140-
self.input_fn(step * self.dt, self.dt), dtype=np.float32)
141-
else:
142-
N = len(genomes)
143-
num_inputs = first_input.shape[-1]
144-
inputs = np.zeros((num_steps, N, num_inputs), dtype=np.float32)
145-
inputs[0] = first_input
146-
for step in range(1, num_steps):
147-
inputs[step] = np.asarray(
148-
self.input_fn(step * self.dt, self.dt), dtype=np.float32)
126+
inputs = np.zeros((num_steps, *first_input.shape), dtype=np.float32)
127+
inputs[0] = first_input
128+
for step in range(1, num_steps):
129+
inputs[step] = np.asarray(
130+
self.input_fn(step * self.dt, self.dt), dtype=np.float32)
149131

150132
packed = pack_iznn_population(genomes, config)
151133
trajectory = evaluate_iznn_batch(packed, inputs, self.dt, num_steps)

0 commit comments

Comments
 (0)