Skip to content

Commit 1e8e8d2

Browse files
committed
Fix usage of if statements in helper functions
1 parent 5defcba commit 1e8e8d2

7 files changed

Lines changed: 340 additions & 10 deletions

File tree

src/strands/ir_builders.js

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,17 @@ function mapPrimitiveDepsToIDs(strandsContext, typeInfo, dependsOn) {
239239
calculatedDimensions += dimension;
240240
continue;
241241
}
242+
else if (typeof dep === 'boolean') {
243+
// Handle boolean literals - convert to bool type
244+
const { id, dimension } = scalarLiteralNode(strandsContext, { dimension: 1, baseType: BaseType.BOOL }, dep);
245+
mappedDependencies.push(id);
246+
calculatedDimensions += dimension;
247+
// Update baseType to BOOL if it was inferred
248+
if (baseType !== BaseType.BOOL) {
249+
baseType = BaseType.BOOL;
250+
}
251+
continue;
252+
}
242253
else {
243254
FES.userError('type error', `You've tried to construct a scalar or vector type with a non-numeric value: ${dep}`);
244255
}
@@ -289,7 +300,12 @@ export function primitiveConstructorNode(strandsContext, typeInfo, dependsOn) {
289300
const { mappedDependencies, inferredTypeInfo } = mapPrimitiveDepsToIDs(strandsContext, typeInfo, dependsOn);
290301

291302
const finalType = {
292-
baseType: typeInfo.baseType,
303+
// We might have inferred a non numeric type. Currently this is
304+
// just used for booleans. Maybe this needs to be something more robust
305+
// if we ever want to support inference of e.g. int vectors?
306+
baseType: inferredTypeInfo.baseType === BaseType.BOOL
307+
? BaseType.BOOL
308+
: typeInfo.baseType,
293309
dimension: inferredTypeInfo.dimension
294310
};
295311

src/strands/ir_types.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ export const BaseType = {
3737
BOOL: "bool",
3838
MAT: "mat",
3939
DEFER: "defer",
40+
ASSIGN_ON_USE: "assign_on_use",
4041
SAMPLER2D: "sampler2D",
4142
SAMPLER: "sampler",
4243
};
@@ -46,6 +47,7 @@ export const BasePriority = {
4647
[BaseType.BOOL]: 1,
4748
[BaseType.MAT]: 0,
4849
[BaseType.DEFER]: -1,
50+
[BaseType.ASSIGN_ON_USE]: -2,
4951
[BaseType.SAMPLER2D]: -10,
5052
[BaseType.SAMPLER]: -11,
5153
};
@@ -66,6 +68,7 @@ export const DataType = {
6668
mat3: { fnName: "mat3x3", baseType: BaseType.MAT, dimension:3, priority: 0, },
6769
mat4: { fnName: "mat4x4", baseType: BaseType.MAT, dimension:4, priority: 0, },
6870
defer: { fnName: null, baseType: BaseType.DEFER, dimension: null, priority: -1 },
71+
assign_on_use: { fnName: null, baseType: BaseType.ASSIGN_ON_USE, dimension: null, priority: -2 },
6972
sampler2D: { fnName: "sampler2D", baseType: BaseType.SAMPLER2D, dimension: 1, priority: -10 },
7073
sampler: { fnName: "sampler", baseType: BaseType.SAMPLER, dimension: 1, priority: -11 },
7174
}

src/strands/strands_api.js

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,20 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) {
198198
if (args.length > 4) {
199199
FES.userError("type error", "It looks like you've tried to construct a p5.strands node implicitly, with more than 4 components. This is currently not supported.")
200200
}
201-
const { id, dimension } = build.primitiveConstructorNode(strandsContext, { baseType: BaseType.FLOAT, dimension: null }, args.flat());
201+
// Filter out undefined/null values
202+
const flatArgs = args.flat();
203+
const definedArgs = flatArgs.filter(a => a !== undefined && a !== null);
204+
205+
// If all args are undefined, this is likely a `let myVar` at the
206+
// start of an if statement and it will be assigned within the branches.
207+
// For that, we use an assign-on-use node, meaning we'll take the type of the
208+
// values assigned to it.
209+
if (definedArgs.length === 0) {
210+
const { id, dimension } = build.primitiveConstructorNode(strandsContext, { baseType: BaseType.ASSIGN_ON_USE, dimension: null }, [0]);
211+
return createStrandsNode(id, dimension, strandsContext);
212+
}
213+
214+
const { id, dimension } = build.primitiveConstructorNode(strandsContext, { baseType: BaseType.FLOAT, dimension: null }, definedArgs);
202215
return createStrandsNode(id, dimension, strandsContext);//new StrandsNode(id, dimension, strandsContext);
203216
}
204217
//////////////////////////////////////////////
@@ -337,7 +350,7 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) {
337350
// variant or also one more directly translated from GLSL, or to be more compatible with
338351
// APIs we documented at the release of 2.x and have to continue supporting.
339352
for (const type in DataType) {
340-
if (type === BaseType.DEFER || type === 'sampler') {
353+
if (type === BaseType.DEFER || type === BaseType.ASSIGN_ON_USE || type === 'sampler') {
341354
continue;
342355
}
343356
const typeInfo = DataType[type];

src/strands/strands_phi_utils.js

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
import * as CFG from './ir_cfg';
22
import * as DAG from './ir_dag';
3-
import { NodeType } from './ir_types';
3+
import { NodeType, BaseType } from './ir_types';
44

55
export function createPhiNode(strandsContext, phiInputs, varName) {
66
// Determine the proper dimension and baseType from the inputs
77
const validInputs = phiInputs.filter(input => input.value.id !== null);
88
if (validInputs.length === 0) {
99
throw new Error(`No valid inputs for phi node for variable ${varName}`);
1010
}
11-
// Get dimension and baseType from first valid input
12-
let firstInput = validInputs
13-
.map((input) => DAG.getNodeDataFromID(strandsContext.dag, input.value.id))
14-
.find((input) => input.dimension) ??
15-
DAG.getNodeDataFromID(strandsContext.dag, validInputs[0].value.id);
11+
12+
// Get dimension and baseType from first valid input, skipping ASSIGN_ON_USE nodes
13+
const inputNodes = validInputs.map((input) => DAG.getNodeDataFromID(strandsContext.dag, input.value.id));
14+
let firstInput = inputNodes.find((input) => input.baseType !== BaseType.ASSIGN_ON_USE && input.dimension) ??
15+
inputNodes.find((input) => input.baseType !== BaseType.ASSIGN_ON_USE) ??
16+
inputNodes[0];
17+
1618
const dimension = firstInput.dimension;
1719
const baseType = firstInput.baseType;
20+
1821
const nodeData = {
1922
nodeType: NodeType.PHI,
2023
dimension,

src/strands/strands_transpiler.js

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,177 @@ const ASTCallbacks = {
11421142
return replaceInNode(node);
11431143
}
11441144
}
1145+
1146+
// Helper function to check if a function body contains return statements in control flow
1147+
function functionHasEarlyReturns(functionNode) {
1148+
let hasEarlyReturn = false;
1149+
let inControlFlow = 0;
1150+
1151+
const checkForEarlyReturns = {
1152+
IfStatement(node, state, c) {
1153+
inControlFlow++;
1154+
if (node.test) c(node.test, state);
1155+
if (node.consequent) c(node.consequent, state);
1156+
if (node.alternate) c(node.alternate, state);
1157+
inControlFlow--;
1158+
},
1159+
ForStatement(node, state, c) {
1160+
inControlFlow++;
1161+
if (node.init) c(node.init, state);
1162+
if (node.test) c(node.test, state);
1163+
if (node.update) c(node.update, state);
1164+
if (node.body) c(node.body, state);
1165+
inControlFlow--;
1166+
},
1167+
ReturnStatement(node) {
1168+
if (inControlFlow > 0) {
1169+
hasEarlyReturn = true;
1170+
}
1171+
}
1172+
};
1173+
1174+
if (functionNode.body && functionNode.body.type === 'BlockStatement') {
1175+
recursive(functionNode.body, {}, checkForEarlyReturns);
1176+
}
1177+
1178+
return hasEarlyReturn;
1179+
}
1180+
1181+
// Helper function to check if an if-statement's consequent contains a return
1182+
function blockContainsReturn(block) {
1183+
let hasReturn = false;
1184+
const findReturn = {
1185+
ReturnStatement() {
1186+
hasReturn = true;
1187+
}
1188+
};
1189+
if (block) {
1190+
recursive(block, {}, findReturn);
1191+
}
1192+
return hasReturn;
1193+
}
1194+
1195+
// Transform a helper function to use __returnValue pattern instead of early returns
1196+
function transformHelperFunction(functionNode) {
1197+
// 1. Add __returnValue declaration at the start of function body
1198+
const returnValueDecl = {
1199+
type: 'VariableDeclaration',
1200+
declarations: [{
1201+
type: 'VariableDeclarator',
1202+
id: { type: 'Identifier', name: '__returnValue' },
1203+
init: null
1204+
}],
1205+
kind: 'let'
1206+
};
1207+
1208+
if (!functionNode.body || functionNode.body.type !== 'BlockStatement') {
1209+
return; // Can't transform arrow functions with expression bodies
1210+
}
1211+
1212+
functionNode.body.body.unshift(returnValueDecl);
1213+
1214+
// 2. Restructure if statements: move siblings after if with return into else block
1215+
function restructureIfStatements(statements) {
1216+
for (let i = 0; i < statements.length; i++) {
1217+
const stmt = statements[i];
1218+
1219+
if (stmt.type === 'IfStatement' && blockContainsReturn(stmt.consequent) && !stmt.alternate) {
1220+
// Find all subsequent statements
1221+
const subsequentStatements = statements.slice(i + 1);
1222+
1223+
if (subsequentStatements.length > 0) {
1224+
// Create else block with subsequent statements
1225+
stmt.alternate = {
1226+
type: 'BlockStatement',
1227+
body: subsequentStatements
1228+
};
1229+
1230+
// Remove the subsequent statements from this level
1231+
statements.splice(i + 1);
1232+
1233+
// Recursively process the new else block
1234+
restructureIfStatements(stmt.alternate.body);
1235+
}
1236+
}
1237+
1238+
// Recursively process nested blocks
1239+
if (stmt.type === 'IfStatement') {
1240+
if (stmt.consequent && stmt.consequent.type === 'BlockStatement') {
1241+
restructureIfStatements(stmt.consequent.body);
1242+
}
1243+
if (stmt.alternate && stmt.alternate.type === 'BlockStatement') {
1244+
restructureIfStatements(stmt.alternate.body);
1245+
}
1246+
} else if (stmt.type === 'ForStatement' && stmt.body && stmt.body.type === 'BlockStatement') {
1247+
restructureIfStatements(stmt.body.body);
1248+
} else if (stmt.type === 'BlockStatement') {
1249+
restructureIfStatements(stmt.body);
1250+
}
1251+
}
1252+
}
1253+
1254+
restructureIfStatements(functionNode.body.body);
1255+
1256+
// 3. Transform all return statements to assignments
1257+
const transformReturns = {
1258+
ReturnStatement(node) {
1259+
// Convert return statement to assignment
1260+
node.type = 'ExpressionStatement';
1261+
node.expression = {
1262+
type: 'AssignmentExpression',
1263+
operator: '=',
1264+
left: { type: 'Identifier', name: '__returnValue' },
1265+
right: node.argument || { type: 'Identifier', name: 'undefined' }
1266+
};
1267+
delete node.argument;
1268+
}
1269+
};
1270+
1271+
recursive(functionNode.body, {}, transformReturns);
1272+
1273+
// 4. Add final return statement
1274+
const finalReturn = {
1275+
type: 'ReturnStatement',
1276+
argument: { type: 'Identifier', name: '__returnValue' }
1277+
};
1278+
1279+
functionNode.body.body.push(finalReturn);
1280+
}
1281+
1282+
// Main transformation pass: find and transform helper functions with early returns
1283+
function transformHelperFunctionEarlyReturns(ast) {
1284+
const helperFunctionsToTransform = [];
1285+
1286+
// Collect helper functions that need transformation
1287+
const collectHelperFunctions = {
1288+
VariableDeclarator(node, ancestors) {
1289+
const init = node.init;
1290+
if (init && (init.type === 'ArrowFunctionExpression' || init.type === 'FunctionExpression')) {
1291+
if (functionHasEarlyReturns(init)) {
1292+
helperFunctionsToTransform.push(init);
1293+
}
1294+
}
1295+
},
1296+
FunctionDeclaration(node, ancestors) {
1297+
if (functionHasEarlyReturns(node)) {
1298+
helperFunctionsToTransform.push(node);
1299+
}
1300+
},
1301+
// Don't transform functions that are direct arguments to call expressions
1302+
CallExpression(node, ancestors) {
1303+
// Arguments to CallExpressions are base callbacks, not helpers
1304+
// We skip them by not adding them to the transformation list
1305+
}
1306+
};
1307+
1308+
ancestor(ast, collectHelperFunctions);
1309+
1310+
// Transform each collected helper function
1311+
for (const funcNode of helperFunctionsToTransform) {
1312+
transformHelperFunction(funcNode);
1313+
}
1314+
}
1315+
11451316
export function transpileStrandsToJS(p5, sourceString, srcLocations, scope) {
11461317
// Reset counters at the start of each transpilation
11471318
blockVarCounter = 0;
@@ -1156,7 +1327,11 @@ const ASTCallbacks = {
11561327
delete nonControlFlowCallbacks.IfStatement;
11571328
delete nonControlFlowCallbacks.ForStatement;
11581329
ancestor(ast, nonControlFlowCallbacks, undefined, { varyings: {} });
1159-
// Second pass: transform if/for statements in post-order using recursive traversal
1330+
1331+
// Second pass: transform helper functions with early returns to use __returnValue pattern
1332+
transformHelperFunctionEarlyReturns(ast);
1333+
1334+
// Third pass: transform if/for statements in post-order using recursive traversal
11601335
const postOrderControlFlowTransform = {
11611336
IfStatement(node, state, c) {
11621337
state.inControlFlow++;

test/unit/webgl/p5.Shader.js

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,66 @@ test('returns numbers for builtin globals outside hooks and a strandNode when ca
11421142
assert.approximately(pixelColor[1], 0, 5);
11431143
assert.approximately(pixelColor[2], 0, 5);
11441144
});
1145+
1146+
test('using boolean intermediate variables in functions', () => {
1147+
myp5.createCanvas(50, 50, myp5.WEBGL);
1148+
1149+
const testShader = myp5.baseFilterShader().modify(() => {
1150+
const conditionMet = () => {
1151+
let condition = 1 > 2;
1152+
let value = 1;
1153+
if (value < 0.5) {
1154+
condition = 0.5 < 2;
1155+
}
1156+
return !condition
1157+
}
1158+
myp5.getColor((inputs, canvasContent) => {
1159+
if (conditionMet()) {
1160+
return [1, 0, 0, 1]
1161+
}
1162+
1163+
return [0.4, 0, 0, 1];
1164+
});
1165+
}, { myp5 });
1166+
1167+
myp5.background(255, 255, 255);
1168+
myp5.filter(testShader);
1169+
1170+
const pixelColor = myp5.get(25, 25);
1171+
assert.approximately(pixelColor[0], 255, 5);
1172+
assert.approximately(pixelColor[1], 0, 5);
1173+
assert.approximately(pixelColor[2], 0, 5);
1174+
});
1175+
1176+
test('using boolean intermediate variables in functions with early returns', () => {
1177+
myp5.createCanvas(50, 50, myp5.WEBGL);
1178+
1179+
const testShader = myp5.baseFilterShader().modify(() => {
1180+
const conditionMet = () => {
1181+
let value = 1;
1182+
if (value < 0.5) {
1183+
return true
1184+
}
1185+
return false
1186+
}
1187+
myp5.getColor((inputs, canvasContent) => {
1188+
if (conditionMet()) {
1189+
return [1, 0, 0, 1]
1190+
}
1191+
1192+
return [0.4, 0, 0, 1];
1193+
});
1194+
}, { myp5 });
1195+
console.log(testShader.fragSrc())
1196+
1197+
myp5.background(255, 255, 255);
1198+
myp5.filter(testShader);
1199+
1200+
const pixelColor = myp5.get(25, 25);
1201+
assert.approximately(pixelColor[0], 102, 5);
1202+
assert.approximately(pixelColor[1], 0, 5);
1203+
assert.approximately(pixelColor[2], 0, 5);
1204+
});
11451205
});
11461206

11471207
suite('for loop statements', () => {

0 commit comments

Comments
 (0)