diff --git a/packages/typegpu/src/tgsl/shaderGenerator_members.ts b/packages/typegpu/src/tgsl/shaderGenerator_members.ts index a33ac92b24..1eb11ae0c9 100644 --- a/packages/typegpu/src/tgsl/shaderGenerator_members.ts +++ b/packages/typegpu/src/tgsl/shaderGenerator_members.ts @@ -2,3 +2,5 @@ export { UnknownData } from '../data/dataTypes.ts'; // types export type { ResolutionCtx } from '../types.ts'; +export type { Snippet } from '../data/snippet.ts'; +export type { Origin } from '../data/snippet.ts'; diff --git a/packages/typegpu/src/tgsl/wgslGenerator.ts b/packages/typegpu/src/tgsl/wgslGenerator.ts index 7c7c06d7c8..fc413e08fe 100644 --- a/packages/typegpu/src/tgsl/wgslGenerator.ts +++ b/packages/typegpu/src/tgsl/wgslGenerator.ts @@ -910,6 +910,75 @@ ${this.ctx.pre}}`; return snip(stitch`${this.ctx.resolve(schema).value}(${args})`, schema, 'runtime'); } + public _return(statement: tinyest.Return): string { + const returnNode = statement[1]; + + if (returnNode !== undefined) { + const expectedReturnType = this.ctx.topFunctionReturnType; + let returnSnippet = expectedReturnType + ? this._typedExpression(returnNode, expectedReturnType) + : this._expression(returnNode); + + if (returnSnippet.value instanceof RefOperator) { + throw new WgslTypeError( + stitch`Cannot return references, returning '${returnSnippet.value.snippet}'`, + ); + } + + // Arguments cannot be returned from functions without copying. A simple example why is: + // const identity = (x) => { + // 'use gpu'; + // return x; + // }; + // + // const foo = (arg: d.v3f) => { + // 'use gpu'; + // const marg = identity(arg); + // marg.x = 1; // 'marg's origin would be 'runtime', so we wouldn't be able to track this misuse. + // }; + if ( + returnSnippet.origin === 'argument' && + !wgsl.isNaturallyEphemeral(returnSnippet.dataType) && + // Only restricting this use in non-entry functions, as the function + // is giving up ownership of all references anyway. + this.ctx.topFunctionScope?.functionType === 'normal' + ) { + throw new WgslTypeError( + stitch`Cannot return references to arguments, returning '${returnSnippet}'. Copy the argument before returning it.`, + ); + } + + if ( + !expectedReturnType && + !isEphemeralSnippet(returnSnippet) && + returnSnippet.origin !== 'this-function' + ) { + const str = this.ctx.resolve(returnSnippet.value, returnSnippet.dataType).value; + const typeStr = this.ctx.resolve(unptr(returnSnippet.dataType)).value; + throw new WgslTypeError( + `'return ${str};' is invalid, cannot return references. +----- +Try 'return ${typeStr}(${str});' instead. +-----`, + ); + } + + returnSnippet = tryConvertSnippet( + this.ctx, + returnSnippet, + unptr(returnSnippet.dataType) as wgsl.AnyWgslData, + false, + ); + + invariant(returnSnippet.dataType !== UnknownData, 'Return type should be known'); + + this.ctx.reportReturnType(returnSnippet.dataType); + return stitch`${this.ctx.pre}return ${returnSnippet};`; + } + + return `${this.ctx.pre}return;`; + } + public _statement(statement: tinyest.Statement): string { if (typeof statement === 'string') { const id = this._identifier(statement); @@ -923,72 +992,7 @@ ${this.ctx.pre}}`; } if (statement[0] === NODE.return) { - const returnNode = statement[1]; - - if (returnNode !== undefined) { - const expectedReturnType = this.ctx.topFunctionReturnType; - let returnSnippet = expectedReturnType - ? this._typedExpression(returnNode, expectedReturnType) - : this._expression(returnNode); - - if (returnSnippet.value instanceof RefOperator) { - throw new WgslTypeError( - stitch`Cannot return references, returning '${returnSnippet.value.snippet}'`, - ); - } - - // Arguments cannot be returned from functions without copying. A simple example why is: - // const identity = (x) => { - // 'use gpu'; - // return x; - // }; - // - // const foo = (arg: d.v3f) => { - // 'use gpu'; - // const marg = identity(arg); - // marg.x = 1; // 'marg's origin would be 'runtime', so we wouldn't be able to track this misuse. - // }; - if ( - returnSnippet.origin === 'argument' && - !wgsl.isNaturallyEphemeral(returnSnippet.dataType) && - // Only restricting this use in non-entry functions, as the function - // is giving up ownership of all references anyway. - this.ctx.topFunctionScope?.functionType === 'normal' - ) { - throw new WgslTypeError( - stitch`Cannot return references to arguments, returning '${returnSnippet}'. Copy the argument before returning it.`, - ); - } - - if ( - !expectedReturnType && - !isEphemeralSnippet(returnSnippet) && - returnSnippet.origin !== 'this-function' - ) { - const str = this.ctx.resolve(returnSnippet.value, returnSnippet.dataType).value; - const typeStr = this.ctx.resolve(unptr(returnSnippet.dataType)).value; - throw new WgslTypeError( - `'return ${str};' is invalid, cannot return references. ------ -Try 'return ${typeStr}(${str});' instead. ------`, - ); - } - - returnSnippet = tryConvertSnippet( - this.ctx, - returnSnippet, - unptr(returnSnippet.dataType) as wgsl.AnyWgslData, - false, - ); - - invariant(returnSnippet.dataType !== UnknownData, 'Return type should be known'); - - this.ctx.reportReturnType(returnSnippet.dataType); - return stitch`${this.ctx.pre}return ${returnSnippet};`; - } - - return `${this.ctx.pre}return;`; + return this._return(statement); } if (statement[0] === NODE.if) { diff --git a/packages/typegpu/tests/std/numeric/add.test.ts b/packages/typegpu/tests/std/numeric/add.test.ts index cbd5a2184b..d8f64b17ea 100644 --- a/packages/typegpu/tests/std/numeric/add.test.ts +++ b/packages/typegpu/tests/std/numeric/add.test.ts @@ -104,22 +104,22 @@ describe('add', () => { it('infers types when adding constants', () => { const int_int = () => { 'use gpu'; - 1 + 2; + return 1 + 2; }; const float_float = () => { 'use gpu'; - 1.1 + 2.3; + return 1.1 + 2.3; }; const int_float = () => { 'use gpu'; - 1.1 + 2; + return 1.1 + 2; }; const float_int = () => { 'use gpu'; - 1 + 2.3; + return 1 + 2.3; }; expectDataTypeOf(int_int).toBe(abstractInt); diff --git a/packages/typegpu/tests/tgsl/memberAccess.test.ts b/packages/typegpu/tests/tgsl/memberAccess.test.ts index dfa33a9f34..41483b66a3 100644 --- a/packages/typegpu/tests/tgsl/memberAccess.test.ts +++ b/packages/typegpu/tests/tgsl/memberAccess.test.ts @@ -1,7 +1,6 @@ import { describe } from 'vitest'; import { it } from 'typegpu-testing-utility'; import { expectSnippetOf } from '../utils/parseResolved.ts'; -import { snip } from '../../src/data/snippet.ts'; import tgpu, { d } from '../../src/index.js'; describe('Member Access', () => { @@ -12,13 +11,13 @@ describe('Member Access', () => { it('should access member properties of literals', () => { expectSnippetOf(() => { 'use gpu'; - Boid().pos; - }).toStrictEqual(snip('Boid().pos', d.vec3f, 'runtime')); + return Boid().pos; + }).toStrictEqual(['Boid().pos', d.vec3f, 'runtime']); expectSnippetOf(() => { 'use gpu'; - Boid().pos.xyz; - }).toStrictEqual(snip('Boid().pos.xyz', d.vec3f, 'runtime')); + return Boid().pos.xyz; + }).toStrictEqual(['Boid().pos.xyz', d.vec3f, 'runtime']); }); it('should access member properties of externals', () => { @@ -26,13 +25,13 @@ describe('Member Access', () => { expectSnippetOf(() => { 'use gpu'; - boid.pos; - }).toStrictEqual(snip(d.vec3f(1, 2, 3), d.vec3f, 'constant')); + return boid.pos; + }).toStrictEqual([d.vec3f(1, 2, 3), d.vec3f, 'constant']); expectSnippetOf(() => { 'use gpu'; - boid.pos.zyx; - }).toStrictEqual(snip(d.vec3f(3, 2, 1), d.vec3f, 'constant')); + return boid.pos.zyx; + }).toStrictEqual([d.vec3f(3, 2, 1), d.vec3f, 'constant']); }); it('should access member properties of variables', () => { @@ -40,13 +39,13 @@ describe('Member Access', () => { expectSnippetOf(() => { 'use gpu'; - boidVar.$.pos; - }).toStrictEqual(snip('boidVar.pos', d.vec3f, 'private')); + return boidVar.$.pos; + }).toStrictEqual(['boidVar.pos', d.vec3f, 'private']); expectSnippetOf(() => { 'use gpu'; - boidVar.$.pos.xyz; - }).toStrictEqual(snip('boidVar.pos.xyz', d.vec3f, 'runtime')); // < swizzles are new objects + return boidVar.$.pos.xyz; + }).toStrictEqual(['boidVar.pos.xyz', d.vec3f, 'runtime']); // < swizzles are new objects }); it('derefs access to local variables with proper address space', () => { @@ -56,8 +55,8 @@ describe('Member Access', () => { const boid = Boid(); // Taking a reference that is local to this function const boidRef = boid; - boidRef.pos; - }).toStrictEqual(snip('(*boidRef).pos', d.vec3f, 'this-function')); + return boidRef.pos; + }).toStrictEqual(['(*boidRef).pos', d.vec3f, 'this-function']); }); it('derefs access to storage with proper address space', ({ root }) => { @@ -68,14 +67,14 @@ describe('Member Access', () => { 'use gpu'; // Taking a reference to a storage variable const boidRef = boidReadonly.$; - boidRef.pos; - }).toStrictEqual(snip('(*boidRef).pos', d.vec3f, 'readonly')); + return boidRef.pos; + }).toStrictEqual(['(*boidRef).pos', d.vec3f, 'readonly']); expectSnippetOf(() => { 'use gpu'; // Taking a reference to a storage variable const boidRef = boidMutable.$; - boidRef.pos; - }).toStrictEqual(snip('(*boidRef).pos', d.vec3f, 'mutable')); + return boidRef.pos; + }).toStrictEqual(['(*boidRef).pos', d.vec3f, 'mutable']); }); }); diff --git a/packages/typegpu/tests/tgsl/wgslGenerator.test.ts b/packages/typegpu/tests/tgsl/wgslGenerator.test.ts index 62c96fe428..64189675d7 100644 --- a/packages/typegpu/tests/tgsl/wgslGenerator.test.ts +++ b/packages/typegpu/tests/tgsl/wgslGenerator.test.ts @@ -16,7 +16,6 @@ import { CodegenState } from '../../src/types.ts'; import { it } from 'typegpu-testing-utility'; import { ArrayExpression } from '../../src/tgsl/generationHelpers.ts'; import { extractSnippetFromFn } from '../utils/parseResolved.ts'; -import { UnknownData } from '../../src/tgsl/shaderGenerator_members.ts'; const { NodeTypeCatalog: NODE } = tinyest; @@ -1086,7 +1085,7 @@ describe('wgslGenerator', () => { it('creates intermediate representation for array expression', () => { const testFn = () => { 'use gpu'; - [d.u32(1), 8, 8, 2]; + return [d.u32(1), 8, 8, 2]; }; const snippet = extractSnippetFromFn(testFn); diff --git a/packages/typegpu/tests/utils/parseResolved.ts b/packages/typegpu/tests/utils/parseResolved.ts index b798ff9f8c..7576fedc31 100644 --- a/packages/typegpu/tests/utils/parseResolved.ts +++ b/packages/typegpu/tests/utils/parseResolved.ts @@ -1,69 +1,67 @@ import type * as tinyest from 'tinyest'; +import { NodeTypeCatalog as NODE } from 'tinyest'; import { type Assertion, expect } from 'vitest'; -import type { BaseData } from '../../src/data/index.ts'; -import type { UnknownData } from '../../src/data/dataTypes.ts'; -import { ResolutionCtxImpl } from '../../src/resolutionCtx.ts'; -import { provideCtx } from '../../src/execMode.ts'; -import { getMetaData } from '../../src/shared/meta.ts'; -import wgslGenerator from '../../src/tgsl/wgslGenerator.ts'; -import { namespace } from '../../src/core/resolve/namespace.ts'; -import type { Snippet } from '../../src/data/snippet.ts'; -import { $internal } from '../../src/shared/symbols.ts'; -import { CodegenState } from '../../src/types.ts'; +import tgpu, { d, ShaderGenerator, WgslGenerator } from 'typegpu'; -export function extractSnippetFromFn(cb: () => unknown): Snippet { - const ctx = new ResolutionCtxImpl({ - namespace: namespace({ names: 'strict' }), - }); +type Snippet = ShaderGenerator.Snippet; +type UnknownData = ShaderGenerator.UnknownData; +type Origin = ShaderGenerator.Origin; - return provideCtx(ctx, () => { - let pushedFnScope = false; - try { - const meta = getMetaData(cb); +class ExtractingGenerator extends WgslGenerator { + #fnDepth: number; - if (!meta || !meta.ast) { - throw new Error('No metadata found for the function'); - } + returnedSnippet: Snippet | undefined; - ctx.pushMode(new CodegenState()); - ctx[$internal].itemStateStack.pushItem(); - ctx[$internal].itemStateStack.pushFunctionScope( - 'normal', - [], - {}, - undefined, - (meta.externals as () => Record)() ?? {}, - ); - ctx.pushBlockScope(); - pushedFnScope = true; + constructor() { + super(); + this.#fnDepth = 0; + } - // Extracting the last expression from the block - const statements = meta.ast.body[1] ?? []; - if (statements.length === 0) { - throw new Error(`Expected at least one expression, got ${statements.length}`); - } + public functionDefinition(body: tinyest.Block): string { + this.#fnDepth++; + try { + return super.functionDefinition(body); + } finally { + this.#fnDepth--; + } + } - wgslGenerator.initGenerator(ctx); - // Prewarming statements - for (const statement of statements) { - wgslGenerator._statement(statement); + public _return(statement: tinyest.Return): string { + if (this.#fnDepth === 1) { + if (this.returnedSnippet) { + throw new Error('Cannot inspect multiple return values'); } - return wgslGenerator._expression(statements[statements.length - 1] as tinyest.Expression); - } finally { - if (pushedFnScope) { - ctx.popBlockScope(); - ctx[$internal].itemStateStack.pop('functionScope'); - ctx[$internal].itemStateStack.pop('item'); + if (!statement[1]) { + throw new Error('Cannot inspect if nothing is returned'); } - ctx.popMode('codegen'); + this.returnedSnippet = this._expression(statement[1]); + return super._return([NODE.return]); } - }); + + // Proceed as usual + return super._return(statement); + } +} + +export function extractSnippetFromFn(cb: () => unknown): Snippet { + const generator = new ExtractingGenerator(); + + tgpu.resolve([cb], { unstable_shaderGenerator: generator }); + + if (!generator.returnedSnippet) { + throw new Error('Something must be returned to be inspected'); + } + + return generator.returnedSnippet; } -export function expectSnippetOf(cb: () => unknown): Assertion { - return expect(extractSnippetFromFn(cb)); +export function expectSnippetOf( + cb: () => unknown, +): Assertion<[unknown, d.BaseData | UnknownData, Origin]> { + const snippet = extractSnippetFromFn(cb); + return expect([snippet.value, snippet.dataType, snippet.origin]); } -export function expectDataTypeOf(cb: () => unknown): Assertion { - return expect(extractSnippetFromFn(cb).dataType); +export function expectDataTypeOf(cb: () => unknown): Assertion { + return expect(extractSnippetFromFn(cb).dataType); }