Skip to content

Commit bd40881

Browse files
committed
feat: Limited support for the runtime ternary operator
1 parent ebb4389 commit bd40881

16 files changed

Lines changed: 350 additions & 73 deletions

File tree

apps/typegpu-docs/tests/individual-example-tests/game-of-life.test.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ describe('game of life example', () => {
5353
5454
fn wrappedCallback(x: u32, y: u32, _arg_2: u32) {
5555
randSeed2(((vec2f(f32(x), f32(y)) / f32(gameSizeUniform)) * timeUniform));
56-
textureStore(next, vec2u(x, y), vec4u(u32(select(0, 1, (randFloat01() > 0.5f))), 0u, 0u, 0u));
56+
textureStore(next, vec2u(x, y), vec4u(u32(select(0i, 1i, (randFloat01() > 0.5f))), 0u, 0u, 0u));
5757
}
5858
5959
@compute @workgroup_size(16, 16, 1) fn mainCompute(@builtin(global_invocation_id) id: vec3u) {
@@ -146,7 +146,7 @@ describe('game of life example', () => {
146146
let current_1 = readTile(lx, ly);
147147
let neighbors = countNeighborsInTile(lx, ly);
148148
let nextAlive = golNextState((current_1 != 0u), neighbors);
149-
textureStore(next, gid.xy, vec4u(u32(select(0, 1, nextAlive)), 0u, 0u, 0u));
149+
textureStore(next, gid.xy, vec4u(u32(select(0i, 1i, nextAlive)), 0u, 0u, 0u));
150150
}
151151
152152
@group(0) @binding(0) var<uniform> gameSizeUniform: u32;
@@ -199,7 +199,7 @@ describe('game of life example', () => {
199199
let current_1 = readTile(lx, ly);
200200
let neighbors = countNeighborsInTile(lx, ly);
201201
let nextAlive = golNextState((current_1 != 0u), neighbors);
202-
textureStore(next, gid.xy, vec4u(u32(select(0, 1, nextAlive)), 0u, 0u, 0u));
202+
textureStore(next, gid.xy, vec4u(u32(select(0i, 1i, nextAlive)), 0u, 0u, 0u));
203203
}
204204
205205
struct fullScreenTriangle_Output {

apps/typegpu-docs/tests/individual-example-tests/jelly-switch.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ describe('jelly switch example', () => {
390390
let ndc = vec2f(((_arg_0.uv.x * 2f) - 1f), -(((_arg_0.uv.y * 2f) - 1f)));
391391
let ray = getRay(ndc);
392392
let color = rayMarch(ray.origin, ray.direction, _arg_0.uv);
393-
let exposure = select(1.5, 2., (darkModeUniform == 1u));
393+
let exposure = select(1.5f, 2f, (darkModeUniform == 1u));
394394
return vec4f(tanh((color.rgb * exposure)), 1f);
395395
}
396396

apps/typegpu-docs/tests/individual-example-tests/probability.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ describe('probability distribution plot example', () => {
251251
let face = u32((sample() * 6f));
252252
let axis = (face % 3u);
253253
var result = vec3f();
254-
result[axis] = f32(select(0, 1, (face > 2u)));
254+
result[axis] = f32(select(0i, 1i, (face > 2u)));
255255
result[((axis + 1u) % 3u)] = sample();
256256
result[((axis + 2u) % 3u)] = sample();
257257
return result;

apps/typegpu-docs/tests/individual-example-tests/ripple-cube.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ describe('ripple-cube example', () => {
262262
const cellSize = 0.0047169811320754715;
263263
let p = ((vec3f(f32(x), f32(y), f32(z)) + 0.5f) * cellSize);
264264
let r = (timeUniform * 0.15f);
265-
let iterCount = select(5, 11, (extendedRippleUniform == 1u));
265+
let iterCount = select(5i, 11i, (extendedRippleUniform == 1u));
266266
var shellD = 1e+10f;
267267
for (var ix = 0; (ix < iterCount); ix++) {
268268
for (var iy = 0; (iy < iterCount); iy++) {

packages/typegpu/src/core/function/createCallableSchema.ts

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
import { type MapValueToSnippet, type ResolvedSnippet, snip } from '../../data/snippet.ts';
1+
import {
2+
type MapValueToSnippet,
3+
noSideEffects,
4+
type ResolvedSnippet,
5+
snip,
6+
type Snippet,
7+
} from '../../data/snippet.ts';
28
import { type BaseData, isPtr } from '../../data/wgslTypes.ts';
39
import { setName } from '../../shared/meta.ts';
410
import { $gpuCallable } from '../../shared/symbols.ts';
@@ -51,10 +57,11 @@ export function callableSchema<T extends AnyFn>(options: CallableSchemaOptions<T
5157
return tryConvertSnippet(ctx, s, argType, false);
5258
}) as MapValueToSnippet<Parameters<T>>;
5359

60+
let result: Snippet;
5461
if (converted.every((s) => isKnownAtComptime(s))) {
5562
ctx.pushMode(new NormalState());
5663
try {
57-
return snip(
64+
result = snip(
5865
options.normalImpl(...(converted.map((s) => s.value) as never[])),
5966
options.schema(),
6067
// Functions give up ownership of their return value
@@ -63,9 +70,14 @@ export function callableSchema<T extends AnyFn>(options: CallableSchemaOptions<T
6370
} finally {
6471
ctx.popMode('normal');
6572
}
73+
} else {
74+
result = options.codegenImpl(ctx, converted);
6675
}
6776

68-
return options.codegenImpl(ctx, converted);
77+
if (!args.some((a) => a.possibleSideEffects)) {
78+
return noSideEffects(result);
79+
}
80+
return result;
6981
},
7082
};
7183

packages/typegpu/src/core/slot/accessor.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { type AnyData, isData } from '../../data/dataTypes.ts';
22
import { schemaCallWrapper } from '../../data/schemaCallWrapper.ts';
3-
import { isSnippet, type ResolvedSnippet, snip } from '../../data/snippet.ts';
3+
import { isSnippet, noSideEffects, type ResolvedSnippet, snip } from '../../data/snippet.ts';
44
import type { BaseData } from '../../data/wgslTypes.ts';
55
import { getResolutionCtx, inCodegenMode } from '../../execMode.ts';
66
import { getName, hasTinyestMetadata, setName } from '../../shared/meta.ts';
@@ -119,7 +119,7 @@ abstract class AccessorBase<
119119
try {
120120
// Doing a deep copy each time so that we don't have to deal with refs
121121
const cloned = schemaCallWrapper(this.schema, value);
122-
return snip(cloned, this.schema, 'constant');
122+
return noSideEffects(snip(cloned, this.schema, 'constant'));
123123
} finally {
124124
ctx.popMode('normal');
125125
}

packages/typegpu/src/data/snippet.ts

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ export interface Snippet {
9292
*/
9393
readonly dataType: BaseData | UnknownData;
9494
readonly origin: Origin;
95+
readonly possibleSideEffects: boolean;
9596
}
9697

9798
export interface ResolvedSnippet extends Snippet {
@@ -105,11 +106,18 @@ class SnippetImpl implements Snippet {
105106
readonly value: unknown;
106107
readonly dataType: BaseData | UnknownData;
107108
readonly origin: Origin;
109+
readonly possibleSideEffects: boolean;
108110

109-
constructor(value: unknown, dataType: BaseData | UnknownData, origin: Origin) {
111+
constructor(
112+
value: unknown,
113+
dataType: BaseData | UnknownData,
114+
origin: Origin,
115+
possibleSideEffects: boolean,
116+
) {
110117
this.value = value;
111118
this.dataType = dataType;
112119
this.origin = origin;
120+
this.possibleSideEffects = possibleSideEffects;
113121
}
114122
}
115123

@@ -138,5 +146,21 @@ export function snip(
138146
// We don't care about attributes in snippet land, so we discard that information.
139147
undecorate(dataType as BaseData),
140148
origin,
149+
/* possibleSideEffects */ true,
141150
);
142151
}
152+
153+
export function withSideEffects(
154+
possibleSideEffects: boolean,
155+
snippet: ResolvedSnippet,
156+
): ResolvedSnippet;
157+
export function withSideEffects(possibleSideEffects: boolean, snippet: Snippet): Snippet;
158+
export function withSideEffects(possibleSideEffects: boolean, snippet: Snippet): Snippet {
159+
return new SnippetImpl(snippet.value, snippet.dataType, snippet.origin, possibleSideEffects);
160+
}
161+
162+
export function noSideEffects(snippet: ResolvedSnippet): ResolvedSnippet;
163+
export function noSideEffects(snippet: Snippet): Snippet;
164+
export function noSideEffects(snippet: Snippet): Snippet {
165+
return withSideEffects(/* possibleSideEffects */ false, snippet);
166+
}

packages/typegpu/src/std/boolean.ts

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
11
import { dualImpl } from '../core/function/dualImpl.ts';
22
import { stitch } from '../core/resolve/stitch.ts';
3-
import { bool, f32 } from '../data/numeric.ts';
3+
import { bool, f16, f32, i32, u32 } from '../data/numeric.ts';
44
import { isSnippetNumeric, snip } from '../data/snippet.ts';
5-
import { vec2b, vec3b, vec4b } from '../data/vector.ts';
5+
import {
6+
vec2b,
7+
vec2f,
8+
vec2h,
9+
vec2i,
10+
vec2u,
11+
vec3b,
12+
vec3f,
13+
vec3h,
14+
vec3i,
15+
vec3u,
16+
vec4b,
17+
vec4f,
18+
vec4h,
19+
vec4i,
20+
vec4u,
21+
} from '../data/vector.ts';
622
import { VectorOps } from '../data/vectorOps.ts';
723
import {
824
type AnyBooleanVecInstance,
@@ -374,6 +390,29 @@ function cpuSelect<T extends number | boolean | AnyVecInstance>(
374390
);
375391
}
376392

393+
export const validSelectBranchTypes: AnyWgslData[] = [
394+
f32,
395+
f16,
396+
i32,
397+
u32,
398+
bool,
399+
vec2f,
400+
vec3f,
401+
vec4f,
402+
vec2h,
403+
vec3h,
404+
vec4h,
405+
vec2i,
406+
vec3i,
407+
vec4i,
408+
vec2u,
409+
vec3u,
410+
vec4u,
411+
vec2b,
412+
vec3b,
413+
vec4b,
414+
];
415+
377416
/**
378417
* Returns `t` if `cond` is `true`, and `f` otherwise.
379418
* Component-wise if `cond` is a vector.
@@ -386,9 +425,27 @@ function cpuSelect<T extends number | boolean | AnyVecInstance>(
386425
export const select = dualImpl({
387426
name: 'select',
388427
signature: (f, t, cond) => {
389-
const [uf, ut] = unify([f, t]) ?? ([f, t] as const);
428+
const [uf, ut] = unify([f, t], validSelectBranchTypes) ?? ([f, t] as const);
390429
return { argTypes: [uf, ut, cond], returnType: uf };
391430
},
392431
normalImpl: cpuSelect,
393-
codegenImpl: (_ctx, [f, t, cond]) => stitch`select(${f}, ${t}, ${cond})`,
432+
codegenImpl: (ctx, [f, t, cond]) => {
433+
const result = stitch`select(${f}, ${t}, ${cond})`;
434+
if (
435+
!validSelectBranchTypes.includes(f.dataType as AnyWgslData) ||
436+
!validSelectBranchTypes.includes(t.dataType as AnyWgslData)
437+
) {
438+
throw new Error(
439+
`'${result}' is invalid, std.select requires both branches to be either scalars or vectors.`,
440+
);
441+
}
442+
if (f.dataType !== t.dataType) {
443+
const fStr = ctx.resolve(f.dataType);
444+
const tStr = ctx.resolve(t.dataType);
445+
throw new Error(
446+
`'${result}' is invalid, std.select requires both branches to be the same type, got [${fStr.value}, ${tStr.value}].`,
447+
);
448+
}
449+
return result;
450+
},
394451
});

packages/typegpu/src/tgsl/accessIndex.ts

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import { stitch } from '../core/resolve/stitch.ts';
22
import { isDisarray, MatrixColumnsAccess } from '../data/dataTypes.ts';
33
import { derefSnippet } from '../data/ref.ts';
4-
import { type Origin, snip, type Snippet } from '../data/snippet.ts';
4+
import { snip, withSideEffects } from '../data/snippet.ts';
5+
import type { Origin, Snippet } from '../data/snippet.ts';
56
import { vec2f, vec3f, vec4f } from '../data/vector.ts';
67
import { type BaseData, isPtr, isVec, isWgslArray, isWgslStruct } from '../data/wgslTypes.ts';
78
import { isKnownAtComptime } from '../types.ts';
@@ -44,25 +45,31 @@ export function accessIndex(target: Snippet, indexArg: Snippet | number): Snippe
4445
return target.value.elements[index.value as number];
4546
}
4647

47-
return snip(
48-
isKnownAtComptime(target) && isKnownAtComptime(index)
49-
? // oxlint-disable-next-line typescript/no-explicit-any -- it's fine, it's there
50-
(target.value as any)[index.value as number]
51-
: stitch`${target}[${index}]`,
52-
elementType,
53-
/* origin */ origin,
48+
return withSideEffects(
49+
target.possibleSideEffects || index.possibleSideEffects,
50+
snip(
51+
isKnownAtComptime(target) && isKnownAtComptime(index)
52+
? // oxlint-disable-next-line typescript/no-explicit-any -- it's fine, it's there
53+
(target.value as any)[index.value as number]
54+
: stitch`${target}[${index}]`,
55+
elementType,
56+
/* origin */ origin,
57+
),
5458
);
5559
}
5660

5761
// vector
5862
if (isVec(target.dataType)) {
59-
return snip(
60-
isKnownAtComptime(target) && isKnownAtComptime(index)
61-
? // oxlint-disable-next-line typescript/no-explicit-any -- it's fine, it's there
62-
(target.value as any)[index.value as any]
63-
: stitch`${target}[${index}]`,
64-
target.dataType.primitive,
65-
/* origin */ target.origin,
63+
return withSideEffects(
64+
target.possibleSideEffects || index.possibleSideEffects,
65+
snip(
66+
isKnownAtComptime(target) && isKnownAtComptime(index)
67+
? // oxlint-disable-next-line typescript/no-explicit-any -- it's fine, it's there
68+
(target.value as any)[index.value as any]
69+
: stitch`${target}[${index}]`,
70+
target.dataType.primitive,
71+
/* origin */ target.origin,
72+
),
6673
);
6774
}
6875

@@ -79,7 +86,10 @@ export function accessIndex(target: Snippet, indexArg: Snippet | number): Snippe
7986
(target.value.matrix.dataType as BaseData).type as keyof typeof indexableTypeToResult
8087
];
8188

82-
return snip(stitch`${target.value.matrix}[${index}]`, propType, /* origin */ target.origin);
89+
return withSideEffects(
90+
target.possibleSideEffects || index.possibleSideEffects,
91+
snip(stitch`${target.value.matrix}[${index}]`, propType, /* origin */ target.origin),
92+
);
8393
}
8494

8595
// matrix

packages/typegpu/src/tgsl/conversion.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { UnknownData } from '../data/dataTypes.ts';
33
import { undecorate } from '../data/dataTypes.ts';
44
import { derefSnippet, RefOperator } from '../data/ref.ts';
55
import { schemaCallWrapperGPU } from '../data/schemaCallWrapper.ts';
6-
import { snip, type Snippet } from '../data/snippet.ts';
6+
import { snip, withSideEffects, type Snippet } from '../data/snippet.ts';
77
import {
88
type AbstractFloat,
99
type AnyWgslData,
@@ -245,11 +245,14 @@ function applyActionToSnippet(
245245
return snippet;
246246
}
247247

248-
return snip(
249-
snippet.value,
250-
targetType,
251-
// if it was a ref, then it's still a ref
252-
/* origin */ snippet.origin,
248+
return withSideEffects(
249+
snippet.possibleSideEffects,
250+
snip(
251+
snippet.value,
252+
targetType,
253+
// if it was a ref, then it's still a ref
254+
/* origin */ snippet.origin,
255+
),
253256
);
254257
}
255258

0 commit comments

Comments
 (0)