Skip to content

Commit 1f0816d

Browse files
authored
Merge pull request #8709 from processing/fix/uniform-callbacks
Make sure we don't transpile uniform callbacks
2 parents a1cd2f9 + de29d3b commit 1f0816d

2 files changed

Lines changed: 308 additions & 40 deletions

File tree

src/strands/strands_transpiler.js

Lines changed: 146 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,50 @@ function nodeIsUniform(ancestor) {
4040
);
4141
}
4242

43+
function nodeIsUniformCallbackFn(node, names) {
44+
if (!names?.size) return false;
45+
if (node.type === 'FunctionDeclaration' && names.has(node.id?.name)) return true;
46+
if (
47+
node.type === 'VariableDeclarator' && names.has(node.id?.name) &&
48+
(node.init?.type === 'FunctionExpression' || node.init?.type === 'ArrowFunctionExpression')
49+
) {
50+
return true;
51+
}
52+
return false;
53+
}
54+
55+
function collectUniformCallbackNames(ast) {
56+
// Sub-pass 1: collect all named function definitions
57+
const namedFunctions = new Set();
58+
ancestor(ast, {
59+
FunctionDeclaration(node) {
60+
if (node.id) namedFunctions.add(node.id.name);
61+
},
62+
VariableDeclarator(node) {
63+
if (
64+
node.id?.type === 'Identifier' &&
65+
(node.init?.type === 'FunctionExpression' || node.init?.type === 'ArrowFunctionExpression')
66+
) {
67+
namedFunctions.add(node.id.name);
68+
}
69+
}
70+
});
71+
// Sub-pass 2: find which of those names are passed as uniform call arguments
72+
const names = new Set();
73+
ancestor(ast, {
74+
CallExpression(node) {
75+
if (nodeIsUniform(node)) {
76+
for (const arg of node.arguments) {
77+
if (arg.type === 'Identifier' && namedFunctions.has(arg.name)) {
78+
names.add(arg.name);
79+
}
80+
}
81+
}
82+
}
83+
});
84+
return names;
85+
}
86+
4387
function nodeIsVarying(node) {
4488
return node && node.type === 'CallExpression'
4589
&& (
@@ -192,8 +236,10 @@ function replaceReferences(node, tempVarMap) {
192236
}
193237

194238
const ASTCallbacks = {
195-
UnaryExpression(node, _state, ancestors) {
196-
if (ancestors.some(nodeIsUniform)) { return; }
239+
UnaryExpression(node, state, ancestors) {
240+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
241+
return;
242+
}
197243
const unaryFnName = UnarySymbolToName[node.operator];
198244
const standardReplacement = (node) => {
199245
node.type = 'CallExpression'
@@ -236,17 +282,21 @@ const ASTCallbacks = {
236282
delete node.argument;
237283
delete node.operator;
238284
},
239-
BreakStatement(node, _state, ancestors) {
240-
if (ancestors.some(nodeIsUniform)) { return; }
285+
BreakStatement(node, state, ancestors) {
286+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
287+
return;
288+
}
241289
node.callee = {
242290
type: 'Identifier',
243291
name: '__p5.break'
244292
};
245293
node.arguments = [];
246294
node.type = 'CallExpression';
247295
},
248-
MemberExpression(node, _state, ancestors) {
249-
if (ancestors.some(nodeIsUniform)) { return; }
296+
MemberExpression(node, state, ancestors) {
297+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
298+
return;
299+
}
250300
// Skip sets -- these will be converted to .set() method
251301
// calls at the AssignmentExpression level
252302
if (
@@ -272,8 +322,10 @@ const ASTCallbacks = {
272322
node.type = 'CallExpression';
273323
}
274324
},
275-
VariableDeclarator(node, _state, ancestors) {
276-
if (ancestors.some(nodeIsUniform)) { return; }
325+
VariableDeclarator(node, state, ancestors) {
326+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
327+
return;
328+
}
277329
if (nodeIsUniform(node.init)) {
278330
// Only inject the variable name if the first argument isn't already a string
279331
if (node.init.arguments.length === 0 ||
@@ -298,16 +350,18 @@ const ASTCallbacks = {
298350
value: node.id.name
299351
}
300352
node.init.arguments.unshift(varyingNameLiteral);
301-
_state.varyings[node.id.name] = varyingNameLiteral;
353+
state.varyings[node.id.name] = varyingNameLiteral;
302354
} else {
303355
// Still track it as a varying even if name wasn't injected
304-
_state.varyings[node.id.name] = node.init.arguments[0];
356+
state.varyings[node.id.name] = node.init.arguments[0];
305357
}
306358
}
307359
},
308-
Identifier(node, _state, ancestors) {
309-
if (ancestors.some(nodeIsUniform)) { return; }
310-
if (_state.varyings[node.name]
360+
Identifier(node, state, ancestors) {
361+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
362+
return;
363+
}
364+
if (state.varyings[node.name]
311365
&& !ancestors.some(a => a.type === 'AssignmentExpression' && a.left === node)
312366
) {
313367
node.type = 'CallExpression';
@@ -327,8 +381,10 @@ const ASTCallbacks = {
327381
},
328382
// The callbacks for AssignmentExpression and BinaryExpression handle
329383
// operator overloading including +=, *= assignment expressions
330-
ArrayExpression(node, _state, ancestors) {
331-
if (ancestors.some(nodeIsUniform)) { return; }
384+
ArrayExpression(node, state, ancestors) {
385+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
386+
return;
387+
}
332388
const original = JSON.parse(JSON.stringify(node));
333389
node.type = 'CallExpression';
334390
node.callee = {
@@ -337,8 +393,10 @@ const ASTCallbacks = {
337393
};
338394
node.arguments = [original];
339395
},
340-
AssignmentExpression(node, _state, ancestors) {
341-
if (ancestors.some(nodeIsUniform)) { return; }
396+
AssignmentExpression(node, state, ancestors) {
397+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
398+
return;
399+
}
342400
const unsafeTypes = ['Literal', 'ArrayExpression', 'Identifier'];
343401
if (node.operator !== '=') {
344402
const methodName = replaceBinaryOperator(node.operator.replace('=',''));
@@ -367,7 +425,7 @@ const ASTCallbacks = {
367425
node.right = rightReplacementNode;
368426
}
369427
// Handle direct varying variable assignment: myVarying = value
370-
if (_state.varyings[node.left.name]) {
428+
if (state.varyings[node.left.name]) {
371429
node.type = 'ExpressionStatement';
372430
node.expression = {
373431
type: 'CallExpression',
@@ -412,15 +470,15 @@ const ASTCallbacks = {
412470
let varyingName = null;
413471

414472
// Check if it's a direct identifier: myVarying.xyz
415-
if (node.left.object.type === 'Identifier' && _state.varyings[node.left.object.name]) {
473+
if (node.left.object.type === 'Identifier' && state.varyings[node.left.object.name]) {
416474
varyingName = node.left.object.name;
417475
}
418476
// Check if it's a getValue() call: myVarying.getValue().xyz
419477
else if (node.left.object.type === 'CallExpression' &&
420478
node.left.object.callee?.type === 'MemberExpression' &&
421479
node.left.object.callee.property?.name === 'getValue' &&
422480
node.left.object.callee.object?.type === 'Identifier' &&
423-
_state.varyings[node.left.object.callee.object.name]) {
481+
state.varyings[node.left.object.callee.object.name]) {
424482
varyingName = node.left.object.callee.object.name;
425483
}
426484

@@ -451,10 +509,12 @@ const ASTCallbacks = {
451509
}
452510
}
453511
},
454-
BinaryExpression(node, _state, ancestors) {
512+
BinaryExpression(node, state, ancestors) {
455513
// Don't convert uniform default values to node methods, as
456514
// they should be evaluated at runtime, not compiled.
457-
if (ancestors.some(nodeIsUniform)) { return; }
515+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
516+
return;
517+
}
458518
// If the left hand side of an expression is one of these types,
459519
// we should construct a node from it.
460520
const unsafeTypes = ['Literal', 'ArrayExpression', 'Identifier'];
@@ -482,10 +542,12 @@ const ASTCallbacks = {
482542
};
483543
node.arguments = [node.right];
484544
},
485-
LogicalExpression(node, _state, ancestors) {
545+
LogicalExpression(node, state, ancestors) {
486546
// Don't convert uniform default values to node methods, as
487547
// they should be evaluated at runtime, not compiled.
488-
if (ancestors.some(nodeIsUniform)) { return; }
548+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
549+
return;
550+
}
489551
// If the left hand side of an expression is one of these types,
490552
// we should construct a node from it.
491553
const unsafeTypes = ['Literal', 'ArrayExpression', 'Identifier'];
@@ -513,8 +575,10 @@ const ASTCallbacks = {
513575
};
514576
node.arguments = [node.right];
515577
},
516-
ConditionalExpression(node, _state, ancestors) {
517-
if (ancestors.some(nodeIsUniform)) { return; }
578+
ConditionalExpression(node, state, ancestors) {
579+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
580+
return;
581+
}
518582
// Transform condition ? consequent : alternate
519583
// into __p5.strandsTernary(condition, consequent, alternate)
520584
const test = node.test;
@@ -527,8 +591,10 @@ const ASTCallbacks = {
527591
delete node.consequent;
528592
delete node.alternate;
529593
},
530-
IfStatement(node, _state, ancestors) {
531-
if (ancestors.some(nodeIsUniform)) { return; }
594+
IfStatement(node, state, ancestors) {
595+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
596+
return;
597+
}
532598
// Transform if statement into strandsIf() call
533599
// The condition is evaluated directly, not wrapped in a function
534600
const condition = node.test;
@@ -796,8 +862,10 @@ const ASTCallbacks = {
796862
delete node.consequent;
797863
delete node.alternate;
798864
},
799-
UpdateExpression(node, _state, ancestors) {
800-
if (ancestors.some(nodeIsUniform)) { return; }
865+
UpdateExpression(node, state, ancestors) {
866+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
867+
return;
868+
}
801869

802870
// Transform ++var, var++, --var, var-- into assignment expressions
803871
let operator;
@@ -828,11 +896,13 @@ const ASTCallbacks = {
828896
// Replace the update expression with the assignment expression
829897
Object.assign(node, assignmentExpr);
830898
delete node.prefix;
831-
this.BinaryExpression(node.right, _state, [...ancestors, node]);
832-
this.AssignmentExpression(node, _state, ancestors);
899+
this.BinaryExpression(node.right, state, [...ancestors, node]);
900+
this.AssignmentExpression(node, state, ancestors);
833901
},
834-
ForStatement(node, _state, ancestors) {
835-
if (ancestors.some(nodeIsUniform)) { return; }
902+
ForStatement(node, state, ancestors) {
903+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, state.uniformCallbackNames))) {
904+
return;
905+
}
836906

837907
// Transform for statement into strandsFor() call
838908
// for (init; test; update) body -> strandsFor(initCb, conditionCb, updateCb, bodyCb, initialVars)
@@ -1538,22 +1608,31 @@ function transformFunctionSetCalls(functionNode) {
15381608
}
15391609

15401610
// Main transformation pass: find and transform functions with .set() calls in control flow
1541-
function transformSetCallsInControlFlow(ast) {
1611+
function transformSetCallsInControlFlow(ast, names) {
15421612
const functionsToTransform = [];
15431613

15441614
// Collect functions that have .set() calls in control flow
15451615
const collectFunctions = {
15461616
ArrowFunctionExpression(node, ancestors) {
1617+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, names))) {
1618+
return;
1619+
}
15471620
if (functionHasSetInControlFlow(node)) {
15481621
functionsToTransform.push(node);
15491622
}
15501623
},
15511624
FunctionExpression(node, ancestors) {
1625+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, names))) {
1626+
return;
1627+
}
15521628
if (functionHasSetInControlFlow(node)) {
15531629
functionsToTransform.push(node);
15541630
}
15551631
},
15561632
FunctionDeclaration(node, ancestors) {
1633+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, names))) {
1634+
return;
1635+
}
15571636
if (functionHasSetInControlFlow(node)) {
15581637
functionsToTransform.push(node);
15591638
}
@@ -1569,12 +1648,15 @@ function transformSetCallsInControlFlow(ast) {
15691648
}
15701649

15711650
// Main transformation pass: find and transform helper functions with early returns
1572-
function transformHelperFunctionEarlyReturns(ast) {
1651+
function transformHelperFunctionEarlyReturns(ast, names) {
15731652
const helperFunctionsToTransform = [];
15741653

15751654
// Collect helper functions that need transformation
15761655
const collectHelperFunctions = {
15771656
VariableDeclarator(node, ancestors) {
1657+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, names))) {
1658+
return;
1659+
}
15781660
const init = node.init;
15791661
if (init && (init.type === 'ArrowFunctionExpression' || init.type === 'FunctionExpression')) {
15801662
if (functionHasEarlyReturns(init)) {
@@ -1583,6 +1665,9 @@ function transformHelperFunctionEarlyReturns(ast) {
15831665
}
15841666
},
15851667
FunctionDeclaration(node, ancestors) {
1668+
if (ancestors.some(a => nodeIsUniform(a) || nodeIsUniformCallbackFn(a, names))) {
1669+
return;
1670+
}
15861671
if (functionHasEarlyReturns(node)) {
15871672
helperFunctionsToTransform.push(node);
15881673
}
@@ -1612,20 +1697,41 @@ export function transpileStrandsToJS(p5, sourceString, srcLocations, scope) {
16121697
locations: srcLocations
16131698
});
16141699

1700+
// Pre-pass: collect names of functions passed by reference as uniform callbacks
1701+
const uniformCallbackNames = collectUniformCallbackNames(ast);
1702+
16151703
// First pass: transform .set() calls in control flow to use intermediate variables
1616-
transformSetCallsInControlFlow(ast);
1704+
transformSetCallsInControlFlow(ast, uniformCallbackNames);
16171705

16181706
// Second pass: transform everything except if/for statements using normal ancestor traversal
16191707
const nonControlFlowCallbacks = { ...ASTCallbacks };
16201708
delete nonControlFlowCallbacks.IfStatement;
16211709
delete nonControlFlowCallbacks.ForStatement;
1622-
ancestor(ast, nonControlFlowCallbacks, undefined, { varyings: {} });
1710+
ancestor(ast, nonControlFlowCallbacks, undefined, { varyings: {}, uniformCallbackNames });
16231711

16241712
// Third pass: transform helper functions with early returns to use __returnValue pattern
1625-
transformHelperFunctionEarlyReturns(ast);
1713+
transformHelperFunctionEarlyReturns(ast, uniformCallbackNames);
16261714

16271715
// Fourth pass: transform if/for statements in post-order using recursive traversal
16281716
const postOrderControlFlowTransform = {
1717+
CallExpression(node, state, c) {
1718+
if (nodeIsUniform(node)) { return; }
1719+
if (node.callee) c(node.callee, state);
1720+
for (const arg of node.arguments) c(arg, state);
1721+
},
1722+
FunctionDeclaration(node, state, c) {
1723+
if (state.uniformCallbackNames?.has(node.id?.name)) return;
1724+
if (node.body) c(node.body, state);
1725+
},
1726+
VariableDeclarator(node, state, c) {
1727+
if (
1728+
state.uniformCallbackNames?.has(node.id?.name) &&
1729+
(node.init?.type === 'FunctionExpression' || node.init?.type === 'ArrowFunctionExpression')
1730+
) {
1731+
return;
1732+
}
1733+
if (node.init) c(node.init, state);
1734+
},
16291735
IfStatement(node, state, c) {
16301736
state.inControlFlow++;
16311737
// First recursively process children
@@ -1662,7 +1768,7 @@ export function transpileStrandsToJS(p5, sourceString, srcLocations, scope) {
16621768
delete node.argument;
16631769
}
16641770
};
1665-
recursive(ast, { varyings: {}, inControlFlow: 0 }, postOrderControlFlowTransform);
1771+
recursive(ast, { varyings: {}, inControlFlow: 0, uniformCallbackNames }, postOrderControlFlowTransform);
16661772
const transpiledSource = escodegen.generate(ast);
16671773
const scopeKeys = Object.keys(scope);
16681774
const match = /\(?\s*(?:function)?\s*\w*\s*\(([^)]*)\)\s*(?:=>)?\s*{((?:.|\n)*)}\s*;?\s*\)?/

0 commit comments

Comments
 (0)