Skip to content

Commit 9b3db12

Browse files
committed
Handle early returns
1 parent a8e04a4 commit 9b3db12

7 files changed

Lines changed: 207 additions & 48 deletions

File tree

src/strands/ir_types.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ export const NodeTypeRequiredFields = {
2727
export const StatementType = {
2828
DISCARD: 'discard',
2929
BREAK: 'break',
30+
EARLY_RETURN: 'early_return',
3031
EXPRESSION: 'expression', // Used when we want to output a single expression as a statement, e.g. a for loop condition
3132
EMPTY: 'empty', // Used for empty statements like ; in for loops
3233
};

src/strands/strands_api.js

Lines changed: 82 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ import {
99
isStructType,
1010
OpCode,
1111
StatementType,
12+
NodeType,
1213
// isNativeType
1314
} from './ir_types'
1415
import { strandsBuiltinFunctions } from './strands_builtins'
1516
import { StrandsConditional } from './strands_conditionals'
1617
import { StrandsFor } from './strands_for'
1718
import * as CFG from './ir_cfg'
19+
import * as DAG from './ir_dag';
1820
import * as FES from './strands_FES'
1921
import { getNodeDataFromID } from './ir_dag'
2022
import { StrandsNode, createStrandsNode } from './strands_node'
@@ -63,6 +65,39 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) {
6365
return new StrandsFor(strandsContext, initialCb, conditionCb, updateCb, bodyCb, initialVars).build();
6466
};
6567
fn.strandsFor = p5.strandsFor;
68+
p5.strandsEarlyReturn = function(value) {
69+
const { dag, cfg } = strandsContext;
70+
71+
// Ensure we're inside a hook
72+
if (!strandsContext.activeHook) {
73+
throw new Error('strandsEarlyReturn can only be used inside a hook callback');
74+
}
75+
76+
// Convert value to a StrandsNode if it isn't already
77+
const valueNode = value instanceof StrandsNode ? value : p5.strandsNode(value);
78+
79+
// Create a new CFG block for the early return
80+
const earlyReturnBlockID = CFG.createBasicBlock(cfg, BlockType.DEFAULT);
81+
CFG.addEdge(cfg, cfg.currentBlock, earlyReturnBlockID);
82+
CFG.pushBlock(cfg, earlyReturnBlockID);
83+
84+
// Create the early return statement node
85+
const nodeData = DAG.createNodeData({
86+
nodeType: NodeType.STATEMENT,
87+
statementType: StatementType.EARLY_RETURN,
88+
dependsOn: [valueNode.id]
89+
});
90+
const earlyReturnID = DAG.getOrCreateNode(dag, nodeData);
91+
CFG.recordInBasicBlock(cfg, cfg.currentBlock, earlyReturnID);
92+
93+
// Add the value to the hook's earlyReturns array for later type checking
94+
strandsContext.activeHook.earlyReturns.push({ earlyReturnID, valueNode });
95+
96+
CFG.popBlock(cfg);
97+
98+
return valueNode;
99+
};
100+
fn.strandsEarlyReturn = p5.strandsEarlyReturn;
66101
p5.strandsNode = function(...args) {
67102
if (args.length === 1 && args[0] instanceof StrandsNode) {
68103
return args[0];
@@ -403,53 +438,62 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) {
403438
CFG.addEdge(cfg, cfg.currentBlock, entryBlockID);
404439
CFG.pushBlock(cfg, entryBlockID);
405440
const args = createHookArguments(strandsContext, hookType.parameters);
441+
strandsContext.activeHook = hookImplementation;
406442
const userReturned = hookUserCallback(...args);
443+
strandsContext.activeHook = undefined;
407444
const expectedReturnType = hookType.returnType;
408445
let rootNodeID = null;
409-
if(isStructType(expectedReturnType)) {
410-
const expectedStructType = structType(expectedReturnType);
411-
if (userReturned instanceof StrandsNode) {
412-
const returnedNode = getNodeDataFromID(strandsContext.dag, userReturned.id);
413-
if (returnedNode.baseType !== expectedStructType.typeName) {
414-
FES.userError("type error", `You have returned a ${userReturned.baseType} from ${hookType.name} when a ${expectedStructType.typeName} was expected.`);
446+
const handleRetVal = (retNode) => {
447+
if(isStructType(expectedReturnType)) {
448+
const expectedStructType = structType(expectedReturnType);
449+
if (retNode instanceof StrandsNode) {
450+
const returnedNode = getNodeDataFromID(strandsContext.dag, retNode.id);
451+
if (returnedNode.baseType !== expectedStructType.typeName) {
452+
FES.userError("type error", `You have returned a ${retNode.baseType} from ${hookType.name} when a ${expectedStructType.typeName} was expected.`);
453+
}
454+
const newDeps = returnedNode.dependsOn.slice();
455+
for (let i = 0; i < expectedStructType.properties.length; i++) {
456+
const expectedType = expectedStructType.properties[i].dataType;
457+
const receivedNode = createStrandsNode(returnedNode.dependsOn[i], dag.dependsOn[retNode.id], strandsContext);
458+
newDeps[i] = enforceReturnTypeMatch(strandsContext, expectedType, receivedNode, hookType.name);
459+
}
460+
dag.dependsOn[retNode.id] = newDeps;
461+
return retNode.id;
415462
}
416-
const newDeps = returnedNode.dependsOn.slice();
417-
for (let i = 0; i < expectedStructType.properties.length; i++) {
418-
const expectedType = expectedStructType.properties[i].dataType;
419-
const receivedNode = createStrandsNode(returnedNode.dependsOn[i], dag.dependsOn[userReturned.id], strandsContext);
420-
newDeps[i] = enforceReturnTypeMatch(strandsContext, expectedType, receivedNode, hookType.name);
463+
else {
464+
const expectedProperties = expectedStructType.properties;
465+
const newStructDependencies = [];
466+
for (let i = 0; i < expectedProperties.length; i++) {
467+
const expectedProp = expectedProperties[i];
468+
const propName = expectedProp.name;
469+
const receivedValue = retNode[propName];
470+
if (receivedValue === undefined) {
471+
FES.userError('type error', `You've returned an incomplete struct from ${hookType.name}.\n` +
472+
`Expected: { ${expectedReturnType.properties.map(p => p.name).join(', ')} }\n` +
473+
`Received: { ${Object.keys(retNode).join(', ')} }\n` +
474+
`All of the properties are required!`);
475+
}
476+
const expectedTypeInfo = expectedProp.dataType;
477+
const returnedPropID = enforceReturnTypeMatch(strandsContext, expectedTypeInfo, receivedValue, hookType.name);
478+
newStructDependencies.push(returnedPropID);
479+
}
480+
const newStruct = build.structConstructorNode(strandsContext, expectedStructType, newStructDependencies);
481+
return newStruct.id;
421482
}
422-
dag.dependsOn[userReturned.id] = newDeps;
423-
rootNodeID = userReturned.id;
424483
}
425-
else {
426-
const expectedProperties = expectedStructType.properties;
427-
const newStructDependencies = [];
428-
for (let i = 0; i < expectedProperties.length; i++) {
429-
const expectedProp = expectedProperties[i];
430-
const propName = expectedProp.name;
431-
const receivedValue = userReturned[propName];
432-
if (receivedValue === undefined) {
433-
FES.userError('type error', `You've returned an incomplete struct from ${hookType.name}.\n` +
434-
`Expected: { ${expectedReturnType.properties.map(p => p.name).join(', ')} }\n` +
435-
`Received: { ${Object.keys(userReturned).join(', ')} }\n` +
436-
`All of the properties are required!`);
437-
}
438-
const expectedTypeInfo = expectedProp.dataType;
439-
const returnedPropID = enforceReturnTypeMatch(strandsContext, expectedTypeInfo, receivedValue, hookType.name);
440-
newStructDependencies.push(returnedPropID);
484+
else /*if(isNativeType(expectedReturnType.typeName))*/ {
485+
if (!expectedReturnType.dataType) {
486+
throw new Error(`Missing dataType for return type ${expectedReturnType.typeName}`);
441487
}
442-
const newStruct = build.structConstructorNode(strandsContext, expectedStructType, newStructDependencies);
443-
rootNodeID = newStruct.id;
488+
const expectedTypeInfo = expectedReturnType.dataType;
489+
return enforceReturnTypeMatch(strandsContext, expectedTypeInfo, retNode, hookType.name);
444490
}
445491
}
446-
else /*if(isNativeType(expectedReturnType.typeName))*/ {
447-
if (!expectedReturnType.dataType) {
448-
throw new Error(`Missing dataType for return type ${expectedReturnType.typeName}`);
449-
}
450-
const expectedTypeInfo = expectedReturnType.dataType;
451-
rootNodeID = enforceReturnTypeMatch(strandsContext, expectedTypeInfo, userReturned, hookType.name);
492+
for (const { valueNode, earlyReturnID } of hookImplementation.earlyReturns) {
493+
const id = handleRetVal(valueNode);
494+
dag.dependsOn[earlyReturnID] = [id];
452495
}
496+
rootNodeID = userReturned ? handleRetVal(userReturned) : undefined;
453497
const fullHookName = `${hookType.returnType.typeName} ${hookType.name}`;
454498
const hookInfo = availableHooks[fullHookName];
455499
strandsContext.hooks.push({
@@ -460,6 +504,7 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) {
460504
});
461505
CFG.popBlock(cfg);
462506
}
507+
hookImplementation.earlyReturns = [];
463508
strandsContext.windowOverrides[hookType.name] = window[hookType.name];
464509
strandsContext.fnOverrides[hookType.name] = fn[hookType.name];
465510
window[hookType.name] = hookImplementation;

src/strands/strands_codegen.js

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import { sortCFG } from "./ir_cfg";
2-
import { structType, TypeInfoFromGLSLName } from './ir_types';
1+
import { sortCFG } from './ir_cfg';
2+
import * as DAG from './ir_dag';
3+
import { NodeType, StatementType, structType, TypeInfoFromGLSLName } from './ir_types';
34

45
export function generateShaderCode(strandsContext) {
56
const {
@@ -68,7 +69,10 @@ export function generateShaderCode(strandsContext) {
6869
}
6970
returnType = hookType.returnType.dataType;
7071
}
71-
backend.generateReturnStatement(strandsContext, generationContext, rootNodeID, returnType);
72+
73+
if (rootNodeID) {
74+
backend.generateReturnStatement(strandsContext, generationContext, rootNodeID, returnType);
75+
}
7276
hooksObj[`${hookType.returnType.typeName} ${hookType.name}`] = [firstLine, ...generationContext.codeLines, '}'].join('\n');
7377
}
7478

src/strands/strands_transpiler.js

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,23 @@ const ASTCallbacks = {
201201
},
202202
AssignmentExpression(node, _state, ancestors) {
203203
if (ancestors.some(nodeIsUniform)) { return; }
204+
const unsafeTypes = ['Literal', 'ArrayExpression', 'Identifier'];
204205
if (node.operator !== '=') {
205206
const methodName = replaceBinaryOperator(node.operator.replace('=',''));
206207
const rightReplacementNode = {
207208
type: 'CallExpression',
208209
callee: {
209210
type: 'MemberExpression',
210-
object: node.left,
211+
object: unsafeTypes.includes(node.left.type)
212+
? {
213+
type: 'CallExpression',
214+
callee: {
215+
type: 'Identifier',
216+
name: '__p5.strandsNode',
217+
},
218+
arguments: [node.left]
219+
}
220+
: node.left,
211221
property: {
212222
type: 'Identifier',
213223
name: methodName,
@@ -402,6 +412,7 @@ const ASTCallbacks = {
402412
},
403413
arguments: [elseFunction]
404414
};
415+
405416
// Analyze which outer scope variables are assigned in any branch
406417
const assignedVars = new Set();
407418

@@ -1007,24 +1018,42 @@ const ASTCallbacks = {
10071018
// Second pass: transform if/for statements in post-order using recursive traversal
10081019
const postOrderControlFlowTransform = {
10091020
IfStatement(node, state, c) {
1021+
state.inControlFlow++;
10101022
// First recursively process children
10111023
if (node.test) c(node.test, state);
10121024
if (node.consequent) c(node.consequent, state);
10131025
if (node.alternate) c(node.alternate, state);
10141026
// Then apply the transformation to this node
10151027
ASTCallbacks.IfStatement(node, state, []);
1028+
state.inControlFlow--;
10161029
},
10171030
ForStatement(node, state, c) {
1031+
state.inControlFlow++;
10181032
// First recursively process children
10191033
if (node.init) c(node.init, state);
10201034
if (node.test) c(node.test, state);
10211035
if (node.update) c(node.update, state);
10221036
if (node.body) c(node.body, state);
10231037
// Then apply the transformation to this node
10241038
ASTCallbacks.ForStatement(node, state, []);
1039+
state.inControlFlow--;
1040+
},
1041+
ReturnStatement(node, state, c) {
1042+
if (!state.inControlFlow) return;
1043+
// Convert return statement to strandsEarlyReturn call
1044+
node.type = 'ExpressionStatement';
1045+
node.expression = {
1046+
type: 'CallExpression',
1047+
callee: {
1048+
type: 'Identifier',
1049+
name: '__p5.strandsEarlyReturn'
1050+
},
1051+
arguments: node.argument ? [node.argument] : []
1052+
};
1053+
delete node.argument;
10251054
}
10261055
};
1027-
recursive(ast, { varyings: {} }, postOrderControlFlowTransform);
1056+
recursive(ast, { varyings: {}, inControlFlow: 0 }, postOrderControlFlowTransform);
10281057
const transpiledSource = escodegen.generate(ast);
10291058
const scopeKeys = Object.keys(scope);
10301059
const match = /\(?\s*(?:function)?\s*\(([^)]*)\)\s*(?:=>)?\s*{((?:.|\n)*)}\s*;?\s*\)?/

src/webgl/strands_glslBackend.js

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,19 +188,22 @@ export const glslBackend = {
188188
},
189189
generateStatement(generationContext, dag, nodeID) {
190190
const node = getNodeDataFromID(dag, nodeID);
191+
// Generate the expression followed by semicolon (unless suppressed)
191192
const semicolon = generationContext.suppressSemicolon ? '' : ';';
192193
if (node.statementType === StatementType.DISCARD) {
193194
generationContext.write(`discard${semicolon}`);
194195
} else if (node.statementType === StatementType.BREAK) {
195196
generationContext.write(`break${semicolon}`);
196197
} else if (node.statementType === StatementType.EXPRESSION) {
197-
// Generate the expression followed by semicolon (unless suppressed)
198198
const exprNodeID = node.dependsOn[0];
199199
const expr = this.generateExpression(generationContext, dag, exprNodeID);
200200
generationContext.write(`${expr}${semicolon}`);
201201
} else if (node.statementType === StatementType.EMPTY) {
202-
// Generate just a semicolon (unless suppressed)
203202
generationContext.write(semicolon);
203+
} else if (node.statementType === StatementType.EARLY_RETURN) {
204+
const exprNodeID = node.dependsOn[0];
205+
const expr = this.generateExpression(generationContext, dag, exprNodeID);
206+
generationContext.write(`return ${expr}${semicolon}`);
204207
}
205208
},
206209
generateAssignment(generationContext, dag, nodeID) {

src/webgpu/strands_wgslBackend.js

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,19 +237,23 @@ export const wgslBackend = {
237237
},
238238
generateStatement(generationContext, dag, nodeID) {
239239
const node = getNodeDataFromID(dag, nodeID);
240+
// Generate the expression followed by semicolon (unless suppressed)
240241
const semicolon = generationContext.suppressSemicolon ? '' : ';';
241242
if (node.statementType === StatementType.DISCARD) {
242243
generationContext.write(`discard${semicolon}`);
243244
} else if (node.statementType === StatementType.BREAK) {
244245
generationContext.write(`break${semicolon}`);
245246
} else if (node.statementType === StatementType.EXPRESSION) {
246-
// Generate the expression followed by semicolon (unless suppressed)
247247
const exprNodeID = node.dependsOn[0];
248248
const expr = this.generateExpression(generationContext, dag, exprNodeID);
249249
generationContext.write(`${expr}${semicolon}`);
250250
} else if (node.statementType === StatementType.EMPTY) {
251251
// Generate just a semicolon (unless suppressed)
252252
generationContext.write(semicolon);
253+
} else if (node.statementType === StatementType.EARLY_RETURN) {
254+
const exprNodeID = node.dependsOn[0];
255+
const expr = this.generateExpression(generationContext, dag, exprNodeID);
256+
generationContext.write(`return ${expr}${semicolon}`);
253257
}
254258
},
255259
generateAssignment(generationContext, dag, nodeID) {

0 commit comments

Comments
 (0)