Skip to content

Commit dd27465

Browse files
authored
Merge pull request #8394 from processing/fix/strands-branching
Fix for strands branching
2 parents 1241837 + 223947f commit dd27465

File tree

7 files changed

+268
-110
lines changed

7 files changed

+268
-110
lines changed

src/strands/ir_types.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ export const NodeTypeRequiredFields = {
2727
export const StatementType = {
2828
DISCARD: 'discard',
2929
BREAK: 'break',
30+
EARLY_RETURN: 'early_return',
3031
EXPRESSION: 'expression', // Used when we want to output a single expression as a statement, e.g. a for loop condition
3132
EMPTY: 'empty', // Used for empty statements like ; in for loops
3233
};

src/strands/strands_api.js

Lines changed: 82 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ import {
99
isStructType,
1010
OpCode,
1111
StatementType,
12+
NodeType,
1213
// isNativeType
1314
} from './ir_types'
1415
import { strandsBuiltinFunctions } from './strands_builtins'
1516
import { StrandsConditional } from './strands_conditionals'
1617
import { StrandsFor } from './strands_for'
1718
import * as CFG from './ir_cfg'
19+
import * as DAG from './ir_dag';
1820
import * as FES from './strands_FES'
1921
import { getNodeDataFromID } from './ir_dag'
2022
import { StrandsNode, createStrandsNode } from './strands_node'
@@ -63,6 +65,39 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) {
6365
return new StrandsFor(strandsContext, initialCb, conditionCb, updateCb, bodyCb, initialVars).build();
6466
};
6567
fn.strandsFor = p5.strandsFor;
68+
p5.strandsEarlyReturn = function(value) {
69+
const { dag, cfg } = strandsContext;
70+
71+
// Ensure we're inside a hook
72+
if (!strandsContext.activeHook) {
73+
throw new Error('strandsEarlyReturn can only be used inside a hook callback');
74+
}
75+
76+
// Convert value to a StrandsNode if it isn't already
77+
const valueNode = value instanceof StrandsNode ? value : p5.strandsNode(value);
78+
79+
// Create a new CFG block for the early return
80+
const earlyReturnBlockID = CFG.createBasicBlock(cfg, BlockType.DEFAULT);
81+
CFG.addEdge(cfg, cfg.currentBlock, earlyReturnBlockID);
82+
CFG.pushBlock(cfg, earlyReturnBlockID);
83+
84+
// Create the early return statement node
85+
const nodeData = DAG.createNodeData({
86+
nodeType: NodeType.STATEMENT,
87+
statementType: StatementType.EARLY_RETURN,
88+
dependsOn: [valueNode.id]
89+
});
90+
const earlyReturnID = DAG.getOrCreateNode(dag, nodeData);
91+
CFG.recordInBasicBlock(cfg, cfg.currentBlock, earlyReturnID);
92+
93+
// Add the value to the hook's earlyReturns array for later type checking
94+
strandsContext.activeHook.earlyReturns.push({ earlyReturnID, valueNode });
95+
96+
CFG.popBlock(cfg);
97+
98+
return valueNode;
99+
};
100+
fn.strandsEarlyReturn = p5.strandsEarlyReturn;
66101
p5.strandsNode = function(...args) {
67102
if (args.length === 1 && args[0] instanceof StrandsNode) {
68103
return args[0];
@@ -403,53 +438,62 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) {
403438
CFG.addEdge(cfg, cfg.currentBlock, entryBlockID);
404439
CFG.pushBlock(cfg, entryBlockID);
405440
const args = createHookArguments(strandsContext, hookType.parameters);
441+
strandsContext.activeHook = hookImplementation;
406442
const userReturned = hookUserCallback(...args);
443+
strandsContext.activeHook = undefined;
407444
const expectedReturnType = hookType.returnType;
408445
let rootNodeID = null;
409-
if(isStructType(expectedReturnType)) {
410-
const expectedStructType = structType(expectedReturnType);
411-
if (userReturned instanceof StrandsNode) {
412-
const returnedNode = getNodeDataFromID(strandsContext.dag, userReturned.id);
413-
if (returnedNode.baseType !== expectedStructType.typeName) {
414-
FES.userError("type error", `You have returned a ${userReturned.baseType} from ${hookType.name} when a ${expectedStructType.typeName} was expected.`);
446+
const handleRetVal = (retNode) => {
447+
if(isStructType(expectedReturnType)) {
448+
const expectedStructType = structType(expectedReturnType);
449+
if (retNode instanceof StrandsNode) {
450+
const returnedNode = getNodeDataFromID(strandsContext.dag, retNode.id);
451+
if (returnedNode.baseType !== expectedStructType.typeName) {
452+
FES.userError("type error", `You have returned a ${retNode.baseType} from ${hookType.name} when a ${expectedStructType.typeName} was expected.`);
453+
}
454+
const newDeps = returnedNode.dependsOn.slice();
455+
for (let i = 0; i < expectedStructType.properties.length; i++) {
456+
const expectedType = expectedStructType.properties[i].dataType;
457+
const receivedNode = createStrandsNode(returnedNode.dependsOn[i], dag.dependsOn[retNode.id], strandsContext);
458+
newDeps[i] = enforceReturnTypeMatch(strandsContext, expectedType, receivedNode, hookType.name);
459+
}
460+
dag.dependsOn[retNode.id] = newDeps;
461+
return retNode.id;
415462
}
416-
const newDeps = returnedNode.dependsOn.slice();
417-
for (let i = 0; i < expectedStructType.properties.length; i++) {
418-
const expectedType = expectedStructType.properties[i].dataType;
419-
const receivedNode = createStrandsNode(returnedNode.dependsOn[i], dag.dependsOn[userReturned.id], strandsContext);
420-
newDeps[i] = enforceReturnTypeMatch(strandsContext, expectedType, receivedNode, hookType.name);
463+
else {
464+
const expectedProperties = expectedStructType.properties;
465+
const newStructDependencies = [];
466+
for (let i = 0; i < expectedProperties.length; i++) {
467+
const expectedProp = expectedProperties[i];
468+
const propName = expectedProp.name;
469+
const receivedValue = retNode[propName];
470+
if (receivedValue === undefined) {
471+
FES.userError('type error', `You've returned an incomplete struct from ${hookType.name}.\n` +
472+
`Expected: { ${expectedReturnType.properties.map(p => p.name).join(', ')} }\n` +
473+
`Received: { ${Object.keys(retNode).join(', ')} }\n` +
474+
`All of the properties are required!`);
475+
}
476+
const expectedTypeInfo = expectedProp.dataType;
477+
const returnedPropID = enforceReturnTypeMatch(strandsContext, expectedTypeInfo, receivedValue, hookType.name);
478+
newStructDependencies.push(returnedPropID);
479+
}
480+
const newStruct = build.structConstructorNode(strandsContext, expectedStructType, newStructDependencies);
481+
return newStruct.id;
421482
}
422-
dag.dependsOn[userReturned.id] = newDeps;
423-
rootNodeID = userReturned.id;
424483
}
425-
else {
426-
const expectedProperties = expectedStructType.properties;
427-
const newStructDependencies = [];
428-
for (let i = 0; i < expectedProperties.length; i++) {
429-
const expectedProp = expectedProperties[i];
430-
const propName = expectedProp.name;
431-
const receivedValue = userReturned[propName];
432-
if (receivedValue === undefined) {
433-
FES.userError('type error', `You've returned an incomplete struct from ${hookType.name}.\n` +
434-
`Expected: { ${expectedReturnType.properties.map(p => p.name).join(', ')} }\n` +
435-
`Received: { ${Object.keys(userReturned).join(', ')} }\n` +
436-
`All of the properties are required!`);
437-
}
438-
const expectedTypeInfo = expectedProp.dataType;
439-
const returnedPropID = enforceReturnTypeMatch(strandsContext, expectedTypeInfo, receivedValue, hookType.name);
440-
newStructDependencies.push(returnedPropID);
484+
else /*if(isNativeType(expectedReturnType.typeName))*/ {
485+
if (!expectedReturnType.dataType) {
486+
throw new Error(`Missing dataType for return type ${expectedReturnType.typeName}`);
441487
}
442-
const newStruct = build.structConstructorNode(strandsContext, expectedStructType, newStructDependencies);
443-
rootNodeID = newStruct.id;
488+
const expectedTypeInfo = expectedReturnType.dataType;
489+
return enforceReturnTypeMatch(strandsContext, expectedTypeInfo, retNode, hookType.name);
444490
}
445491
}
446-
else /*if(isNativeType(expectedReturnType.typeName))*/ {
447-
if (!expectedReturnType.dataType) {
448-
throw new Error(`Missing dataType for return type ${expectedReturnType.typeName}`);
449-
}
450-
const expectedTypeInfo = expectedReturnType.dataType;
451-
rootNodeID = enforceReturnTypeMatch(strandsContext, expectedTypeInfo, userReturned, hookType.name);
492+
for (const { valueNode, earlyReturnID } of hookImplementation.earlyReturns) {
493+
const id = handleRetVal(valueNode);
494+
dag.dependsOn[earlyReturnID] = [id];
452495
}
496+
rootNodeID = userReturned ? handleRetVal(userReturned) : undefined;
453497
const fullHookName = `${hookType.returnType.typeName} ${hookType.name}`;
454498
const hookInfo = availableHooks[fullHookName];
455499
strandsContext.hooks.push({
@@ -460,6 +504,7 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) {
460504
});
461505
CFG.popBlock(cfg);
462506
}
507+
hookImplementation.earlyReturns = [];
463508
strandsContext.windowOverrides[hookType.name] = window[hookType.name];
464509
strandsContext.fnOverrides[hookType.name] = fn[hookType.name];
465510
window[hookType.name] = hookImplementation;

src/strands/strands_codegen.js

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import { sortCFG } from "./ir_cfg";
2-
import { structType, TypeInfoFromGLSLName } from './ir_types';
1+
import { sortCFG } from './ir_cfg';
2+
import * as DAG from './ir_dag';
3+
import { NodeType, StatementType, structType, TypeInfoFromGLSLName } from './ir_types';
34

45
export function generateShaderCode(strandsContext) {
56
const {
@@ -68,7 +69,10 @@ export function generateShaderCode(strandsContext) {
6869
}
6970
returnType = hookType.returnType.dataType;
7071
}
71-
backend.generateReturnStatement(strandsContext, generationContext, rootNodeID, returnType);
72+
73+
if (rootNodeID) {
74+
backend.generateReturnStatement(strandsContext, generationContext, rootNodeID, returnType);
75+
}
7276
hooksObj[`${hookType.returnType.typeName} ${hookType.name}`] = [firstLine, ...generationContext.codeLines, '}'].join('\n');
7377
}
7478

0 commit comments

Comments
 (0)