Skip to content

Commit 6ca4282

Browse files
committed
MORE FASTER
1 parent 7673c98 commit 6ca4282

2 files changed

Lines changed: 111 additions & 74 deletions

File tree

  • apps/typegpu-docs/src/examples/algorithms/genetic-racing

apps/typegpu-docs/src/examples/algorithms/genetic-racing/ga.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ export const SimParams = d.struct({
6060
spawnX: d.f32,
6161
spawnY: d.f32,
6262
spawnAngle: d.f32,
63+
stepsPerDispatch: d.u32,
6364
});
6465

6566
export const CarStateArray = d.arrayOf(CarState, MAX_POP);

apps/typegpu-docs/src/examples/algorithms/genetic-racing/index.ts

Lines changed: 110 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ const canvas = document.querySelector('canvas') as HTMLCanvasElement;
2323
const context = root.configureContext({ canvas, alphaMode: 'premultiplied' });
2424
const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
2525

26+
const STEPS_PER_DISPATCH = 32;
27+
2628
const BASE_SPATIAL_PARAMS = {
2729
maxSpeed: 1.6,
2830
accel: 0.2,
@@ -44,6 +46,7 @@ const params = root.createUniform(SimParams, {
4446
spawnX: 0,
4547
spawnY: 0,
4648
spawnAngle: 0,
49+
stepsPerDispatch: STEPS_PER_DISPATCH,
4750
...BASE_SPATIAL_PARAMS,
4851
});
4952

@@ -157,76 +160,98 @@ const simulatePipeline = root.createGuardedComputePipeline((i) => {
157160
if (d.u32(i) >= params.$.population) {
158161
return;
159162
}
160-
const car = CarState(simLayout.$.state[i]);
161-
const genome = Genome(simLayout.$.genome[i]);
162-
const wasAlive = car.alive === 1;
163163

164-
const carForward = d.vec2f(std.cos(car.angle), std.sin(car.angle));
165-
const aheadPos = car.position + carForward * params.$.sensorDistance;
166-
167-
const inputs4 = d.vec4f(
168-
senseRaycast(car.position, car.angle, DEG_60),
169-
senseRaycast(car.position, car.angle, DEG_30),
170-
senseRaycast(car.position, car.angle, 0),
171-
senseRaycast(car.position, car.angle, -DEG_30),
172-
);
173-
const inputsB = d.vec4f(
174-
senseRaycast(car.position, car.angle, -DEG_60),
175-
car.speed / params.$.maxSpeed,
176-
std.dot(carForward, sampleTrack(car.position, nearestSampler.$).xy),
177-
trackCross(carForward, aheadPos),
178-
);
179-
const inputsC = d.vec4f(
180-
car.angVel / params.$.turnRate,
181-
senseRaycast(car.position, car.angle, DEG_90),
182-
senseRaycast(car.position, car.angle, -DEG_90),
183-
trackCross(carForward, car.position + carForward * params.$.sensorDistance * 2),
184-
);
164+
const genome = Genome(simLayout.$.genome[i]);
165+
const initCar = CarState(simLayout.$.state[i]);
166+
167+
let curPosition = d.vec2f(initCar.position);
168+
let curAngle = initCar.angle;
169+
let curSpeed = initCar.speed;
170+
let curAlive = initCar.alive;
171+
let curProgress = initCar.progress;
172+
let curAngVel = initCar.angVel;
173+
let curAliveSteps = initCar.aliveSteps;
174+
let curStallSteps = initCar.stallSteps;
175+
176+
for (let s = d.u32(0); s < params.$.stepsPerDispatch; s++) {
177+
if (curAlive === 0) {
178+
break;
179+
}
185180

186-
const control = evalNetwork(genome, inputs4, inputsB, inputsC);
187-
const steer = control.x;
188-
const throttle = control.y;
189-
190-
let speed = car.speed + throttle * params.$.accel * params.$.dt;
191-
speed = speed * (1 - params.$.drag * speed * params.$.dt);
192-
speed = std.clamp(speed, 0, params.$.maxSpeed);
193-
194-
const slowThreshold = params.$.maxSpeed * 0.04;
195-
const canTurn = speed > slowThreshold;
196-
const normSpeed = speed / params.$.maxSpeed;
197-
const turnFactor = (1 - normSpeed) * (1 - normSpeed);
198-
const targetAngVel = std.select(0, steer * params.$.turnRate * turnFactor, canTurn);
199-
const angVel = car.angVel * 0.75 + targetAngVel * 0.25;
200-
const angle = car.angle + angVel * params.$.dt;
201-
202-
const dir = d.vec2f(std.cos(angle), std.sin(angle));
203-
const position = car.position + dir * speed * params.$.dt;
204-
const step = position - car.position;
205-
206-
const stallSteps = std.select(d.u32(0), car.stallSteps + 1, speed < slowThreshold);
207-
const trackEnd = sampleTrack(position, nearestSampler.$);
208-
const onTrack =
209-
wasAlive &&
210-
stallSteps < 120 &&
211-
trackEnd.z > 0.5 &&
212-
isOnTrack(car.position + step * 0.33) &&
213-
isOnTrack(car.position + step * 0.66);
214-
215-
const alive = std.select(d.u32(0), d.u32(1), onTrack);
216-
const forward = std.dot(dir, trackEnd.xy);
217-
const lapLength = params.$.trackLength * params.$.trackScale;
218-
const progress =
219-
car.progress + (speed * std.max(0, forward) * params.$.dt * d.f32(alive)) / lapLength;
181+
const carForward = d.vec2f(std.cos(curAngle), std.sin(curAngle));
182+
const aheadPos = curPosition + carForward * params.$.sensorDistance;
183+
184+
const inputs4 = d.vec4f(
185+
senseRaycast(curPosition, curAngle, DEG_60),
186+
senseRaycast(curPosition, curAngle, DEG_30),
187+
senseRaycast(curPosition, curAngle, 0),
188+
senseRaycast(curPosition, curAngle, -DEG_30),
189+
);
190+
const inputsB = d.vec4f(
191+
senseRaycast(curPosition, curAngle, -DEG_60),
192+
curSpeed / params.$.maxSpeed,
193+
std.dot(carForward, sampleTrack(curPosition, nearestSampler.$).xy),
194+
trackCross(carForward, aheadPos),
195+
);
196+
const inputsC = d.vec4f(
197+
curAngVel / params.$.turnRate,
198+
senseRaycast(curPosition, curAngle, DEG_90),
199+
senseRaycast(curPosition, curAngle, -DEG_90),
200+
trackCross(carForward, curPosition + carForward * params.$.sensorDistance * 2),
201+
);
202+
203+
const control = evalNetwork(genome, inputs4, inputsB, inputsC);
204+
const steer = control.x;
205+
const throttle = control.y;
206+
207+
let speed = curSpeed + throttle * params.$.accel * params.$.dt;
208+
speed = speed * (1 - params.$.drag * speed * params.$.dt);
209+
speed = std.clamp(speed, 0, params.$.maxSpeed);
210+
211+
const slowThreshold = params.$.maxSpeed * 0.04;
212+
const canTurn = speed > slowThreshold;
213+
const normSpeed = speed / params.$.maxSpeed;
214+
const turnFactor = (1 - normSpeed) * (1 - normSpeed);
215+
const targetAngVel = std.select(0, steer * params.$.turnRate * turnFactor, canTurn);
216+
const angVel = curAngVel * 0.75 + targetAngVel * 0.25;
217+
const angle = curAngle + angVel * params.$.dt;
218+
219+
const dir = d.vec2f(std.cos(angle), std.sin(angle));
220+
const position = curPosition + dir * speed * params.$.dt;
221+
const stepVec = position - curPosition;
222+
223+
const stallSteps = std.select(d.u32(0), curStallSteps + 1, speed < slowThreshold);
224+
const trackEnd = sampleTrack(position, nearestSampler.$);
225+
const onTrack =
226+
stallSteps < 120 &&
227+
trackEnd.z > 0.5 &&
228+
isOnTrack(curPosition + stepVec * 0.33) &&
229+
isOnTrack(curPosition + stepVec * 0.66);
230+
231+
const alive = std.select(d.u32(0), d.u32(1), onTrack);
232+
const forward = std.dot(dir, trackEnd.xy);
233+
const lapLength = params.$.trackLength * params.$.trackScale;
234+
235+
curPosition = std.select(curPosition, position, onTrack);
236+
curAngle = std.select(curAngle, angle, onTrack);
237+
curSpeed = std.select(0, speed, onTrack);
238+
curAlive = alive;
239+
curProgress =
240+
curProgress + (speed * std.max(0, forward) * params.$.dt * d.f32(alive)) / lapLength;
241+
curAngVel = std.select(0, angVel, onTrack);
242+
curAliveSteps = curAliveSteps + 1;
243+
curStallSteps = stallSteps;
244+
}
220245

221246
simLayout.$.state[i] = CarState({
222-
position: std.select(car.position, position, onTrack),
223-
angle: std.select(car.angle, angle, onTrack),
224-
speed: std.select(0, speed, onTrack),
225-
alive,
226-
progress,
227-
angVel: std.select(0, angVel, onTrack),
228-
aliveSteps: car.aliveSteps + std.select(d.u32(0), d.u32(1), wasAlive),
229-
stallSteps,
247+
position: curPosition,
248+
angle: curAngle,
249+
speed: curSpeed,
250+
alive: curAlive,
251+
progress: curProgress,
252+
angVel: curAngVel,
253+
aliveSteps: curAliveSteps,
254+
stallSteps: curStallSteps,
230255
});
231256
});
232257

@@ -363,7 +388,7 @@ const carPipeline = root.createRenderPipeline({
363388
});
364389

365390
let steps = 0;
366-
let stepsPerFrame = 1;
391+
let stepsPerFrame = STEPS_PER_DISPATCH;
367392
let stepsPerGeneration = 2048;
368393
let paused = false;
369394
let lastAspect = 1;
@@ -416,18 +441,29 @@ function frame() {
416441
}
417442

418443
const stepsToRun = Math.min(stepsPerFrame, stepsPerGeneration - steps);
419-
for (let i = 0; i < stepsToRun; i++) {
420-
simulatePipeline.with(simBindGroups[ga.current]).dispatchThreads(population);
444+
const innerSteps = Math.min(stepsToRun, STEPS_PER_DISPATCH);
445+
params.writePartial({ stepsPerDispatch: innerSteps });
446+
const dispatchCount = Math.ceil(stepsToRun / innerSteps);
447+
448+
const simEncoder = root.device.createCommandEncoder();
449+
for (let dispatch = 0; dispatch < dispatchCount; dispatch++) {
450+
simulatePipeline.with(simBindGroups[ga.current]).with(simEncoder).dispatchThreads(population);
421451
}
422-
steps += stepsToRun;
452+
root.device.queue.submit([simEncoder.finish()]);
453+
454+
steps += dispatchCount * innerSteps;
423455

424456
if (steps >= stepsPerGeneration) {
425457
pendingEvolve = true;
426458
ga.precomputeFitness(population);
427-
reductionPackedBuffer.clear();
428459
const bg = reductionBindGroups[ga.current];
429-
reductionPipeline.with(bg).dispatchThreads(population);
430-
finalizeReductionPipeline.with(bg).dispatchThreads(1);
460+
461+
const reductionEncoder = root.device.createCommandEncoder();
462+
reductionEncoder.clearBuffer(root.unwrap(reductionPackedBuffer));
463+
reductionPipeline.with(bg).with(reductionEncoder).dispatchThreads(population);
464+
finalizeReductionPipeline.with(bg).with(reductionEncoder).dispatchThreads(1);
465+
root.device.queue.submit([reductionEncoder.finish()]);
466+
431467
hasChampion = true;
432468
void bestFitnessBuffer.read().then((fitness) => {
433469
displayedBestFitness = fitness;
@@ -534,7 +570,7 @@ export const controls = defineControls({
534570
'Steps per frame': {
535571
initial: stepsPerFrame,
536572
min: 1,
537-
max: 1024,
573+
max: 8192,
538574
step: 1,
539575
onSliderChange: (value: number) => {
540576
stepsPerFrame = value;

0 commit comments

Comments
 (0)