Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/strands/strands_api.js
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) {
const nodeData = DAG.createNodeData({
nodeType: NodeType.STATEMENT,
statementType: StatementType.EARLY_RETURN,
dependsOn: [valueNode.id]
dependsOn: value !== undefined ? [valueNode.id] : []
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mind elaborating on what these changes are there to handle? Anything we should have more test cases for in the tests?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes fix void return types in compute shaders. Without them, a bare `return;` in a compute hook would crash with "Missing dataType". Most compute shaders return void (they run for side effects only), so the auto-spread wouldn't work without this fix.

For tests - should I add cases for void hooks with early returns? The main compute functionality already has test coverage.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, got it. Right, let's add a test for early returns, since this case wasn't covered by any tests before. Thanks!

Copy link
Copy Markdown
Author

@aashu2006 aashu2006 Apr 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added the test cases for void compute hooks with early returns. Both tests are passing.
Thanks!

});
const earlyReturnID = DAG.getOrCreateNode(dag, nodeData);
CFG.recordInBasicBlock(cfg, cfg.currentBlock, earlyReturnID);
Expand Down Expand Up @@ -786,17 +786,21 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) {
return newStruct.id;
}
}
else if (!expectedReturnType.dataType || expectedReturnType.typeName?.trim() === 'void') {
return null;
}
else /*if(isNativeType(expectedReturnType.typeName))*/ {
if (!expectedReturnType.dataType) {
throw new Error(`Missing dataType for return type ${expectedReturnType.typeName}`);
}
const expectedTypeInfo = expectedReturnType.dataType;
return enforceReturnTypeMatch(strandsContext, expectedTypeInfo, retNode, hookType.name);
}
}
for (const { valueNode, earlyReturnID } of hook.earlyReturns) {
const id = handleRetVal(valueNode);
dag.dependsOn[earlyReturnID] = [id];
if (id !== null) {
dag.dependsOn[earlyReturnID] = [id];
} else {
dag.dependsOn[earlyReturnID] = [];
}
}
rootNodeID = userReturned ? handleRetVal(userReturned) : undefined;
const fullHookName = `${hookType.returnType.typeName} ${hookType.name}`;
Expand Down
5 changes: 1 addition & 4 deletions src/strands/strands_codegen.js
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,9 @@ export function generateShaderCode(strandsContext) {
let returnType;
if (hookType.returnType.properties) {
returnType = structType(hookType.returnType);
} else if (hookType.returnType.typeName === 'void') {
} else if (!hookType.returnType.dataType || hookType.returnType.typeName?.trim() === 'void') {
returnType = null;
} else {
if (!hookType.returnType.dataType) {
throw new Error(`Missing dataType for return type ${hookType.returnType.typeName}`);
}
returnType = hookType.returnType.dataType;
}

Expand Down
43 changes: 39 additions & 4 deletions src/webgpu/p5.RendererWebGPU.js
Original file line number Diff line number Diff line change
Expand Up @@ -3813,10 +3813,45 @@ ${hookUniformFields}}
const WORKGROUP_SIZE_Y = 8;
const WORKGROUP_SIZE_Z = 1;

// Calculate number of workgroups needed
const workgroupCountX = Math.ceil(x / WORKGROUP_SIZE_X);
const workgroupCountY = Math.ceil(y / WORKGROUP_SIZE_Y);
const workgroupCountZ = Math.ceil(z / WORKGROUP_SIZE_Z);
// auto spreading: if any dimension is too large or for performance optimization,
// spread total iteration count across dimensions
const totalIterations = x * y * z;
const MAX_THREADS_PER_DIM = 65535 * 8;

let px = x;
let py = y;
let pz = z;

// we spread if we exceed GPU limits OR if it involves a large 1D dispatch
const exceedsLimits = x > MAX_THREADS_PER_DIM || y > MAX_THREADS_PER_DIM || z > MAX_THREADS_PER_DIM;
const isLarge1D = totalIterations > 1024 && y === 1 && z === 1;

if (exceedsLimits || isLarge1D) {
if (totalIterations > 1000000) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, is there any benefit to spreading across dimensions like this for lower iteration counts too? E.g., if you're doing a big for loop inside each iteration, with a smaller number of iterations, is there any difference?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question! Currently I only auto-spread when the count is > 1024, to avoid overhead for small dispatches. For lower counts with heavy per-iteration work, manual spreading might still help, but I kept it simple for now. We could test this if you think it's worth optimizing?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's worth testing at least to know what kind of difference it makes, and similarly if it's better to spread across 3 dimensions earlier too. A sort of table of performance tests would help us just be a bit more confident about our optimizations.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can run some quick tests comparing different spreading approaches across small, medium, and large counts. I’ll check 1D, 2D (square/rectangular), and 3D, and share a simple performance table with the results. Should be interesting to see where things start to slow down.

Let me know if there’s anything specific you’d like me to test, or if you want to try something on your machine as well 👍

// 3D cube type for extreme large counts
px = Math.ceil(Math.pow(totalIterations, 1 / 3));
py = Math.ceil(Math.pow(totalIterations, 1 / 3));
pz = Math.ceil(totalIterations / (px * py));
} else {
// 2D square type for moderate large counts
px = Math.ceil(Math.sqrt(totalIterations));
py = Math.ceil(totalIterations / px);
pz = 1;
}

if (p5.debug || exceedsLimits) {
console.warn(
`p5.js: Compute dispatch (${x}, ${y}, ${z}) auto-spread to (${px}, ${py}, ${pz}) ` +
`to ${exceedsLimits ? 'stay within GPU limits' : 'optimize performance'}.`
);
}
}

shader.setUniform('uPhysicalCount', [px, py, pz]);

const workgroupCountX = Math.ceil(px / WORKGROUP_SIZE_X);
const workgroupCountY = Math.ceil(py / WORKGROUP_SIZE_Y);
const workgroupCountZ = Math.ceil(pz / WORKGROUP_SIZE_Z);

const commandEncoder = this.device.createCommandEncoder();
const passEncoder = commandEncoder.beginComputePass();
Expand Down
16 changes: 10 additions & 6 deletions src/webgpu/shaders/compute.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export const baseComputeShader = `
struct ComputeUniforms {
uTotalCount: vec3<i32>,
uPhysicalCount: vec3<i32>,
}
@group(0) @binding(0) var<uniform> uniforms: ComputeUniforms;

Expand All @@ -11,16 +12,19 @@ fn main(
@builtin(workgroup_id) workgroupId: vec3<u32>,
@builtin(local_invocation_index) localIndex: u32
) {
var index = vec3<i32>(globalId);
let totalIterations = u32(uniforms.uTotalCount.x) * u32(uniforms.uTotalCount.y) * u32(uniforms.uTotalCount.z);
let physicalId = globalId.x + globalId.y * (u32(uniforms.uPhysicalCount.x)) + globalId.z * (u32(uniforms.uPhysicalCount.x) * u32(uniforms.uPhysicalCount.y));

if (
index.x >= uniforms.uTotalCount.x ||
index.y >= uniforms.uTotalCount.y ||
index.z >= uniforms.uTotalCount.z
) {
if (physicalId >= totalIterations) {
return;
}

var index = vec3<i32>(0);
index.x = i32(physicalId % u32(uniforms.uTotalCount.x));
let remainingY = physicalId / u32(uniforms.uTotalCount.x);
index.y = i32(remainingY % u32(uniforms.uTotalCount.y));
index.z = i32(remainingY / u32(uniforms.uTotalCount.y));

HOOK_iteration(index);
}
`;
10 changes: 7 additions & 3 deletions src/webgpu/strands_wgslBackend.js
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,13 @@ export const wgslBackend = {
// Generate just a semicolon (unless suppressed)
generationContext.write(semicolon);
} else if (node.statementType === StatementType.EARLY_RETURN) {
const exprNodeID = node.dependsOn[0];
const expr = this.generateExpression(generationContext, dag, exprNodeID);
generationContext.write(`return ${expr}${semicolon}`);
if (node.dependsOn && node.dependsOn.length > 0) {
const exprNodeID = node.dependsOn[0];
const expr = this.generateExpression(generationContext, dag, exprNodeID);
generationContext.write(`return ${expr}${semicolon}`);
} else {
generationContext.write(`return${semicolon}`);
}
}
},
generateAssignment(generationContext, dag, nodeID) {
Expand Down
41 changes: 41 additions & 0 deletions test/unit/webgpu/p5.Shader.js
Original file line number Diff line number Diff line change
Expand Up @@ -1228,5 +1228,46 @@ suite('WebGPU p5.Shader', function() {
});
}
});

// Tests for void compute hooks that contain an early `return;`.
suite('compute shaders', () => {
test('handle early return in void compute hook', async () => {
await myp5.createCanvas(5, 5, myp5.WEBGPU);

// This test verifies that buildComputeShader and p5.compute
// correctly handle void hooks with early returns without crashing
// the strands compiler or hitting type errors.
expect(() => {
const computeShader = myp5.buildComputeShader(() => {
const id = myp5.index.x;
if (id > 10) {
return; // Early return in void hook
}
}, { myp5 });

// Dispatch with a count of 1; building and dispatching both happen inside
// the expect() callback so a throw from either one fails the test.
myp5.compute(computeShader, 1);
}).not.toThrow();
});

test('early return in void compute hook stops execution', async () => {
await myp5.createCanvas(5, 5, myp5.WEBGPU);
// Single-element storage, bound below to the hook's `buf` uniform.
const data = myp5.createStorage([0]);

const computeShader = myp5.buildComputeShader(() => {
const buf = myp5.uniformStorage();
const id = myp5.index.x;
if (id == 0) {
buf[0] = 1.0;
return;
buf[0] = 2.0; // Should not execute
}
}, { myp5 });

computeShader.setUniform('buf', data);

// NOTE(review): this only asserts that the dispatch does not throw. The
// contents of `data` are never read back, so the "stops execution" claim in
// the test name is not verified here (presumably buf[0] should end up as
// 1.0, not 2.0) — consider adding a readback assertion.
expect(() => {
myp5.compute(computeShader, 1);
}).not.toThrow();
});
});
});
});
Loading