Skip to content

Commit edc606f

Browse files
committed
chore: Use public API in test utilities (extractSnippet, ...)
1 parent c23c36b commit edc606f

6 files changed

Lines changed: 130 additions & 132 deletions

File tree

packages/typegpu/src/tgsl/shaderGenerator_members.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ export { UnknownData } from '../data/dataTypes.ts';
22

33
// types
44
export type { ResolutionCtx } from '../types.ts';
5+
export type { Snippet } from '../data/snippet.ts';

packages/typegpu/src/tgsl/wgslGenerator.ts

Lines changed: 70 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,75 @@ ${this.ctx.pre}}`;
910910
return snip(stitch`${this.ctx.resolve(schema).value}(${args})`, schema, 'runtime');
911911
}
912912

913+
public _return(statement: tinyest.Return): string {
914+
const returnNode = statement[1];
915+
916+
if (returnNode !== undefined) {
917+
const expectedReturnType = this.ctx.topFunctionReturnType;
918+
let returnSnippet = expectedReturnType
919+
? this._typedExpression(returnNode, expectedReturnType)
920+
: this._expression(returnNode);
921+
922+
if (returnSnippet.value instanceof RefOperator) {
923+
throw new WgslTypeError(
924+
stitch`Cannot return references, returning '${returnSnippet.value.snippet}'`,
925+
);
926+
}
927+
928+
// Arguments cannot be returned from functions without copying. A simple example why is:
929+
// const identity = (x) => {
930+
// 'use gpu';
931+
// return x;
932+
// };
933+
//
934+
// const foo = (arg: d.v3f) => {
935+
// 'use gpu';
936+
// const marg = identity(arg);
937+
// marg.x = 1; // 'marg's origin would be 'runtime', so we wouldn't be able to track this misuse.
938+
// };
939+
if (
940+
returnSnippet.origin === 'argument' &&
941+
!wgsl.isNaturallyEphemeral(returnSnippet.dataType) &&
942+
// Only restricting this use in non-entry functions, as the function
943+
// is giving up ownership of all references anyway.
944+
this.ctx.topFunctionScope?.functionType === 'normal'
945+
) {
946+
throw new WgslTypeError(
947+
stitch`Cannot return references to arguments, returning '${returnSnippet}'. Copy the argument before returning it.`,
948+
);
949+
}
950+
951+
if (
952+
!expectedReturnType &&
953+
!isEphemeralSnippet(returnSnippet) &&
954+
returnSnippet.origin !== 'this-function'
955+
) {
956+
const str = this.ctx.resolve(returnSnippet.value, returnSnippet.dataType).value;
957+
const typeStr = this.ctx.resolve(unptr(returnSnippet.dataType)).value;
958+
throw new WgslTypeError(
959+
`'return ${str};' is invalid, cannot return references.
960+
-----
961+
Try 'return ${typeStr}(${str});' instead.
962+
-----`,
963+
);
964+
}
965+
966+
returnSnippet = tryConvertSnippet(
967+
this.ctx,
968+
returnSnippet,
969+
unptr(returnSnippet.dataType) as wgsl.AnyWgslData,
970+
false,
971+
);
972+
973+
invariant(returnSnippet.dataType !== UnknownData, 'Return type should be known');
974+
975+
this.ctx.reportReturnType(returnSnippet.dataType);
976+
return stitch`${this.ctx.pre}return ${returnSnippet};`;
977+
}
978+
979+
return `${this.ctx.pre}return;`;
980+
}
981+
913982
public _statement(statement: tinyest.Statement): string {
914983
if (typeof statement === 'string') {
915984
const id = this._identifier(statement);
@@ -923,72 +992,7 @@ ${this.ctx.pre}}`;
923992
}
924993

925994
if (statement[0] === NODE.return) {
926-
const returnNode = statement[1];
927-
928-
if (returnNode !== undefined) {
929-
const expectedReturnType = this.ctx.topFunctionReturnType;
930-
let returnSnippet = expectedReturnType
931-
? this._typedExpression(returnNode, expectedReturnType)
932-
: this._expression(returnNode);
933-
934-
if (returnSnippet.value instanceof RefOperator) {
935-
throw new WgslTypeError(
936-
stitch`Cannot return references, returning '${returnSnippet.value.snippet}'`,
937-
);
938-
}
939-
940-
// Arguments cannot be returned from functions without copying. A simple example why is:
941-
// const identity = (x) => {
942-
// 'use gpu';
943-
// return x;
944-
// };
945-
//
946-
// const foo = (arg: d.v3f) => {
947-
// 'use gpu';
948-
// const marg = identity(arg);
949-
// marg.x = 1; // 'marg's origin would be 'runtime', so we wouldn't be able to track this misuse.
950-
// };
951-
if (
952-
returnSnippet.origin === 'argument' &&
953-
!wgsl.isNaturallyEphemeral(returnSnippet.dataType) &&
954-
// Only restricting this use in non-entry functions, as the function
955-
// is giving up ownership of all references anyway.
956-
this.ctx.topFunctionScope?.functionType === 'normal'
957-
) {
958-
throw new WgslTypeError(
959-
stitch`Cannot return references to arguments, returning '${returnSnippet}'. Copy the argument before returning it.`,
960-
);
961-
}
962-
963-
if (
964-
!expectedReturnType &&
965-
!isEphemeralSnippet(returnSnippet) &&
966-
returnSnippet.origin !== 'this-function'
967-
) {
968-
const str = this.ctx.resolve(returnSnippet.value, returnSnippet.dataType).value;
969-
const typeStr = this.ctx.resolve(unptr(returnSnippet.dataType)).value;
970-
throw new WgslTypeError(
971-
`'return ${str};' is invalid, cannot return references.
972-
-----
973-
Try 'return ${typeStr}(${str});' instead.
974-
-----`,
975-
);
976-
}
977-
978-
returnSnippet = tryConvertSnippet(
979-
this.ctx,
980-
returnSnippet,
981-
unptr(returnSnippet.dataType) as wgsl.AnyWgslData,
982-
false,
983-
);
984-
985-
invariant(returnSnippet.dataType !== UnknownData, 'Return type should be known');
986-
987-
this.ctx.reportReturnType(returnSnippet.dataType);
988-
return stitch`${this.ctx.pre}return ${returnSnippet};`;
989-
}
990-
991-
return `${this.ctx.pre}return;`;
995+
return this._return(statement);
992996
}
993997

994998
if (statement[0] === NODE.if) {

packages/typegpu/tests/std/numeric/add.test.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,22 +104,22 @@ describe('add', () => {
104104
it('infers types when adding constants', () => {
105105
const int_int = () => {
106106
'use gpu';
107-
1 + 2;
107+
return 1 + 2;
108108
};
109109

110110
const float_float = () => {
111111
'use gpu';
112-
1.1 + 2.3;
112+
return 1.1 + 2.3;
113113
};
114114

115115
const int_float = () => {
116116
'use gpu';
117-
1.1 + 2;
117+
return 1.1 + 2;
118118
};
119119

120120
const float_int = () => {
121121
'use gpu';
122-
1 + 2.3;
122+
return 1 + 2.3;
123123
};
124124

125125
expectDataTypeOf(int_int).toBe(abstractInt);

packages/typegpu/tests/tgsl/memberAccess.test.ts

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ describe('Member Access', () => {
1212
it('should access member properties of literals', () => {
1313
expectSnippetOf(() => {
1414
'use gpu';
15-
Boid().pos;
15+
return Boid().pos;
1616
}).toStrictEqual(snip('Boid().pos', d.vec3f, 'runtime'));
1717

1818
expectSnippetOf(() => {
1919
'use gpu';
20-
Boid().pos.xyz;
20+
return Boid().pos.xyz;
2121
}).toStrictEqual(snip('Boid().pos.xyz', d.vec3f, 'runtime'));
2222
});
2323

@@ -26,12 +26,12 @@ describe('Member Access', () => {
2626

2727
expectSnippetOf(() => {
2828
'use gpu';
29-
boid.pos;
29+
return boid.pos;
3030
}).toStrictEqual(snip(d.vec3f(1, 2, 3), d.vec3f, 'constant'));
3131

3232
expectSnippetOf(() => {
3333
'use gpu';
34-
boid.pos.zyx;
34+
return boid.pos.zyx;
3535
}).toStrictEqual(snip(d.vec3f(3, 2, 1), d.vec3f, 'constant'));
3636
});
3737

@@ -40,12 +40,12 @@ describe('Member Access', () => {
4040

4141
expectSnippetOf(() => {
4242
'use gpu';
43-
boidVar.$.pos;
43+
return boidVar.$.pos;
4444
}).toStrictEqual(snip('boidVar.pos', d.vec3f, 'private'));
4545

4646
expectSnippetOf(() => {
4747
'use gpu';
48-
boidVar.$.pos.xyz;
48+
return boidVar.$.pos.xyz;
4949
}).toStrictEqual(snip('boidVar.pos.xyz', d.vec3f, 'runtime')); // < swizzles are new objects
5050
});
5151

@@ -56,7 +56,7 @@ describe('Member Access', () => {
5656
const boid = Boid();
5757
// Taking a reference that is local to this function
5858
const boidRef = boid;
59-
boidRef.pos;
59+
return boidRef.pos;
6060
}).toStrictEqual(snip('(*boidRef).pos', d.vec3f, 'this-function'));
6161
});
6262

@@ -68,14 +68,14 @@ describe('Member Access', () => {
6868
'use gpu';
6969
// Taking a reference to a storage variable
7070
const boidRef = boidReadonly.$;
71-
boidRef.pos;
71+
return boidRef.pos;
7272
}).toStrictEqual(snip('(*boidRef).pos', d.vec3f, 'readonly'));
7373

7474
expectSnippetOf(() => {
7575
'use gpu';
7676
// Taking a reference to a storage variable
7777
const boidRef = boidMutable.$;
78-
boidRef.pos;
78+
return boidRef.pos;
7979
}).toStrictEqual(snip('(*boidRef).pos', d.vec3f, 'mutable'));
8080
});
8181
});

packages/typegpu/tests/tgsl/wgslGenerator.test.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import { CodegenState } from '../../src/types.ts';
1616
import { it } from 'typegpu-testing-utility';
1717
import { ArrayExpression } from '../../src/tgsl/generationHelpers.ts';
1818
import { extractSnippetFromFn } from '../utils/parseResolved.ts';
19-
import { UnknownData } from '../../src/tgsl/shaderGenerator_members.ts';
2019

2120
const { NodeTypeCatalog: NODE } = tinyest;
2221

@@ -1086,7 +1085,7 @@ describe('wgslGenerator', () => {
10861085
it('creates intermediate representation for array expression', () => {
10871086
const testFn = () => {
10881087
'use gpu';
1089-
[d.u32(1), 8, 8, 2];
1088+
return [d.u32(1), 8, 8, 2];
10901089
};
10911090

10921091
const snippet = extractSnippetFromFn(testFn);
Lines changed: 45 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,63 @@
11
import type * as tinyest from 'tinyest';
2+
import { NodeTypeCatalog as NODE } from 'tinyest';
23
import { type Assertion, expect } from 'vitest';
3-
import type { BaseData } from '../../src/data/index.ts';
4-
import type { UnknownData } from '../../src/data/dataTypes.ts';
5-
import { ResolutionCtxImpl } from '../../src/resolutionCtx.ts';
6-
import { provideCtx } from '../../src/execMode.ts';
7-
import { getMetaData } from '../../src/shared/meta.ts';
8-
import wgslGenerator from '../../src/tgsl/wgslGenerator.ts';
9-
import { namespace } from '../../src/core/resolve/namespace.ts';
10-
import type { Snippet } from '../../src/data/snippet.ts';
11-
import { $internal } from '../../src/shared/symbols.ts';
12-
import { CodegenState } from '../../src/types.ts';
4+
import tgpu, { d, ShaderGenerator, WgslGenerator } from 'typegpu';
135

14-
export function extractSnippetFromFn(cb: () => unknown): Snippet {
15-
const ctx = new ResolutionCtxImpl({
16-
namespace: namespace({ names: 'strict' }),
17-
});
6+
type Snippet = ShaderGenerator.Snippet;
7+
type UnknownData = ShaderGenerator.UnknownData;
188

19-
return provideCtx(ctx, () => {
20-
let pushedFnScope = false;
21-
try {
22-
const meta = getMetaData(cb);
9+
class ExtractingGenerator extends WgslGenerator {
10+
#fnDepth: number;
2311

24-
if (!meta || !meta.ast) {
25-
throw new Error('No metadata found for the function');
26-
}
12+
returnedSnippet: Snippet | undefined;
2713

28-
ctx.pushMode(new CodegenState());
29-
ctx[$internal].itemStateStack.pushItem();
30-
ctx[$internal].itemStateStack.pushFunctionScope(
31-
'normal',
32-
[],
33-
{},
34-
undefined,
35-
(meta.externals as () => Record<string, string>)() ?? {},
36-
);
37-
ctx.pushBlockScope();
38-
pushedFnScope = true;
14+
constructor() {
15+
super();
16+
this.#fnDepth = 0;
17+
}
3918

40-
// Extracting the last expression from the block
41-
const statements = meta.ast.body[1] ?? [];
42-
if (statements.length === 0) {
43-
throw new Error(`Expected at least one expression, got ${statements.length}`);
44-
}
19+
public functionDefinition(body: tinyest.Block): string {
20+
this.#fnDepth++;
21+
try {
22+
return super.functionDefinition(body);
23+
} finally {
24+
this.#fnDepth--;
25+
}
26+
}
4527

46-
wgslGenerator.initGenerator(ctx);
47-
// Prewarming statements
48-
for (const statement of statements) {
49-
wgslGenerator._statement(statement);
28+
public _return(statement: tinyest.Return): string {
29+
if (this.#fnDepth === 1) {
30+
if (this.returnedSnippet) {
31+
throw new Error('Cannot inspect multiple return values');
5032
}
51-
return wgslGenerator._expression(statements[statements.length - 1] as tinyest.Expression);
52-
} finally {
53-
if (pushedFnScope) {
54-
ctx.popBlockScope();
55-
ctx[$internal].itemStateStack.pop('functionScope');
56-
ctx[$internal].itemStateStack.pop('item');
33+
if (!statement[1]) {
34+
throw new Error('Cannot inspect if nothing is returned');
5735
}
58-
ctx.popMode('codegen');
36+
this.returnedSnippet = this._expression(statement[1]);
37+
return super._return([NODE.return]);
5938
}
60-
});
39+
40+
// Proceed as usual
41+
return super._return(statement);
42+
}
43+
}
44+
45+
export function extractSnippetFromFn(cb: () => unknown): Snippet {
46+
const generator = new ExtractingGenerator();
47+
48+
tgpu.resolve([cb], { unstable_shaderGenerator: generator });
49+
50+
if (!generator.returnedSnippet) {
51+
throw new Error('Something must be returned to be inspected');
52+
}
53+
54+
return generator.returnedSnippet;
6155
}
6256

6357
export function expectSnippetOf(cb: () => unknown): Assertion<Snippet> {
6458
return expect(extractSnippetFromFn(cb));
6559
}
6660

67-
export function expectDataTypeOf(cb: () => unknown): Assertion<BaseData | UnknownData> {
68-
return expect<BaseData | UnknownData>(extractSnippetFromFn(cb).dataType);
61+
export function expectDataTypeOf(cb: () => unknown): Assertion<d.BaseData | UnknownData> {
62+
return expect<d.BaseData | UnknownData>(extractSnippetFromFn(cb).dataType);
6963
}

0 commit comments

Comments
 (0)