Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ describe('game of life example', () => {

fn wrappedCallback(x: u32, y: u32, _arg_2: u32) {
randSeed2(((vec2f(f32(x), f32(y)) / f32(gameSizeUniform)) * timeUniform));
textureStore(next, vec2u(x, y), vec4u(u32(select(0, 1, (randFloat01() > 0.5f))), 0u, 0u, 0u));
textureStore(next, vec2u(x, y), vec4u(u32(select(0i, 1i, (randFloat01() > 0.5f))), 0u, 0u, 0u));
}

@compute @workgroup_size(16, 16, 1) fn mainCompute(@builtin(global_invocation_id) id: vec3u) {
Expand Down Expand Up @@ -146,7 +146,7 @@ describe('game of life example', () => {
let current_1 = readTile(lx, ly);
let neighbors = countNeighborsInTile(lx, ly);
let nextAlive = golNextState((current_1 != 0u), neighbors);
textureStore(next, gid.xy, vec4u(u32(select(0, 1, nextAlive)), 0u, 0u, 0u));
textureStore(next, gid.xy, vec4u(u32(select(0i, 1i, nextAlive)), 0u, 0u, 0u));
}

@group(0) @binding(0) var<uniform> gameSizeUniform: u32;
Expand Down Expand Up @@ -199,7 +199,7 @@ describe('game of life example', () => {
let current_1 = readTile(lx, ly);
let neighbors = countNeighborsInTile(lx, ly);
let nextAlive = golNextState((current_1 != 0u), neighbors);
textureStore(next, gid.xy, vec4u(u32(select(0, 1, nextAlive)), 0u, 0u, 0u));
textureStore(next, gid.xy, vec4u(u32(select(0i, 1i, nextAlive)), 0u, 0u, 0u));
}

struct fullScreenTriangle_Output {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ describe('jelly switch example', () => {
let ndc = vec2f(((_arg_0.uv.x * 2f) - 1f), -(((_arg_0.uv.y * 2f) - 1f)));
let ray = getRay(ndc);
let color = rayMarch(ray.origin, ray.direction, _arg_0.uv);
let exposure = select(1.5, 2., (darkModeUniform == 1u));
let exposure = select(1.5f, 2f, (darkModeUniform == 1u));
return vec4f(tanh((color.rgb * exposure)), 1f);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ describe('probability distribution plot example', () => {
let face = u32((sample() * 6f));
let axis = (face % 3u);
var result = vec3f();
result[axis] = f32(select(0, 1, (face > 2u)));
result[axis] = f32(select(0i, 1i, (face > 2u)));
result[((axis + 1u) % 3u)] = sample();
result[((axis + 2u) % 3u)] = sample();
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ describe('ripple-cube example', () => {
const cellSize = 0.0047169811320754715;
let p = ((vec3f(f32(x), f32(y), f32(z)) + 0.5f) * cellSize);
let r = (timeUniform * 0.15f);
let iterCount = select(5, 11, (extendedRippleUniform == 1u));
let iterCount = select(5i, 11i, (extendedRippleUniform == 1u));
var shellD = 1e+10f;
for (var ix = 0; (ix < iterCount); ix++) {
for (var iy = 0; (iy < iterCount); iy++) {
Expand Down
18 changes: 15 additions & 3 deletions packages/typegpu/src/core/function/createCallableSchema.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import { type MapValueToSnippet, type ResolvedSnippet, snip } from '../../data/snippet.ts';
import {
type MapValueToSnippet,
noSideEffects,
type ResolvedSnippet,
snip,
type Snippet,
} from '../../data/snippet.ts';
import { type BaseData, isPtr } from '../../data/wgslTypes.ts';
import { setName } from '../../shared/meta.ts';
import { $gpuCallable } from '../../shared/symbols.ts';
Expand Down Expand Up @@ -51,10 +57,11 @@ export function callableSchema<T extends AnyFn>(options: CallableSchemaOptions<T
return tryConvertSnippet(ctx, s, argType, false);
}) as MapValueToSnippet<Parameters<T>>;

let result: Snippet;
if (converted.every((s) => isKnownAtComptime(s))) {
ctx.pushMode(new NormalState());
try {
return snip(
result = snip(
options.normalImpl(...(converted.map((s) => s.value) as never[])),
options.schema(),
// Functions give up ownership of their return value
Expand All @@ -63,9 +70,14 @@ export function callableSchema<T extends AnyFn>(options: CallableSchemaOptions<T
} finally {
ctx.popMode('normal');
}
} else {
result = options.codegenImpl(ctx, converted);
}

return options.codegenImpl(ctx, converted);
if (!args.some((a) => a.possibleSideEffects)) {
return noSideEffects(result);
}
return result;
},
};

Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/core/slot/accessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ abstract class AccessorBase<
try {
// Doing a deep copy each time so that we don't have to deal with refs
const cloned = schemaCallWrapper(this.schema, value);
return snip(cloned, this.schema, 'constant');
return snip(cloned, this.schema, 'constant', /* possibleSideEffects */ false);
} finally {
ctx.popMode('normal');
}
Expand Down
50 changes: 47 additions & 3 deletions packages/typegpu/src/data/snippet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ export interface Snippet {
*/
readonly dataType: BaseData | UnknownData;
readonly origin: Origin;
readonly possibleSideEffects: boolean;
}

export interface ResolvedSnippet extends Snippet {
Expand All @@ -105,11 +106,18 @@ class SnippetImpl implements Snippet {
readonly value: unknown;
readonly dataType: BaseData | UnknownData;
readonly origin: Origin;
readonly possibleSideEffects: boolean;

constructor(value: unknown, dataType: BaseData | UnknownData, origin: Origin) {
constructor(
value: unknown,
dataType: BaseData | UnknownData,
origin: Origin,
possibleSideEffects: boolean,
) {
this.value = value;
this.dataType = dataType;
this.origin = origin;
this.possibleSideEffects = possibleSideEffects;
}
}

Expand All @@ -121,12 +129,23 @@ export function isSnippetNumeric(snippet: Snippet) {
return isNumericSchema(snippet.dataType);
}

export function snip(value: string, dataType: BaseData, origin: Origin): ResolvedSnippet;
export function snip(value: unknown, dataType: BaseData | UnknownData, origin: Origin): Snippet;
export function snip(
value: string,
dataType: BaseData,
origin: Origin,
possibleSideEffects?: boolean,
): ResolvedSnippet;
export function snip(
value: unknown,
dataType: BaseData | UnknownData,
origin: Origin,
possibleSideEffects?: boolean,
): Snippet;
export function snip(
value: unknown,
dataType: BaseData | UnknownData,
origin: Origin,
possibleSideEffects: boolean = true,
): Snippet | ResolvedSnippet {
if (DEV && isSnippet(value)) {
// An early error, but not worth checking every time in production
Expand All @@ -138,5 +157,30 @@ export function snip(
// We don't care about attributes in snippet land, so we discard that information.
undecorate(dataType as BaseData),
origin,
possibleSideEffects,
);
}

export function withDataType(
dataType: BaseData | UnknownData,
snippet: ResolvedSnippet,
): ResolvedSnippet;
export function withDataType(dataType: BaseData | UnknownData, snippet: Snippet): Snippet;
export function withDataType(dataType: BaseData | UnknownData, snippet: Snippet): Snippet {
return new SnippetImpl(snippet.value, dataType, snippet.origin, snippet.possibleSideEffects);
}

export function withSideEffects(
possibleSideEffects: boolean,
snippet: ResolvedSnippet,
): ResolvedSnippet;
export function withSideEffects(possibleSideEffects: boolean, snippet: Snippet): Snippet;
export function withSideEffects(possibleSideEffects: boolean, snippet: Snippet): Snippet {
return new SnippetImpl(snippet.value, snippet.dataType, snippet.origin, possibleSideEffects);
}

export function noSideEffects(snippet: ResolvedSnippet): ResolvedSnippet;
export function noSideEffects(snippet: Snippet): Snippet;
export function noSideEffects(snippet: Snippet): Snippet {
return withSideEffects(/* possibleSideEffects */ false, snippet);
}
Comment thread
iwoplaza marked this conversation as resolved.
65 changes: 61 additions & 4 deletions packages/typegpu/src/std/boolean.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
import { dualImpl } from '../core/function/dualImpl.ts';
import { stitch } from '../core/resolve/stitch.ts';
import { bool, f32 } from '../data/numeric.ts';
import { bool, f16, f32, i32, u32 } from '../data/numeric.ts';
import { isSnippetNumeric, snip } from '../data/snippet.ts';
import { vec2b, vec3b, vec4b } from '../data/vector.ts';
import {
vec2b,
vec2f,
vec2h,
vec2i,
vec2u,
vec3b,
vec3f,
vec3h,
vec3i,
vec3u,
vec4b,
vec4f,
vec4h,
vec4i,
vec4u,
} from '../data/vector.ts';
import { VectorOps } from '../data/vectorOps.ts';
import {
type AnyBooleanVecInstance,
Expand Down Expand Up @@ -374,6 +390,29 @@ function cpuSelect<T extends number | boolean | AnyVecInstance>(
);
}

export const validSelectBranchTypes: AnyWgslData[] = [
f32,
f16,
i32,
u32,
bool,
vec2f,
vec3f,
vec4f,
vec2h,
vec3h,
vec4h,
vec2i,
vec3i,
vec4i,
vec2u,
vec3u,
vec4u,
vec2b,
vec3b,
vec4b,
];

/**
* Returns `t` if `cond` is `true`, and `f` otherwise.
* Component-wise if `cond` is a vector.
Expand All @@ -386,9 +425,27 @@ function cpuSelect<T extends number | boolean | AnyVecInstance>(
export const select = dualImpl({
name: 'select',
signature: (f, t, cond) => {
const [uf, ut] = unify([f, t]) ?? ([f, t] as const);
const [uf, ut] = unify([f, t], validSelectBranchTypes) ?? ([f, t] as const);
return { argTypes: [uf, ut, cond], returnType: uf };
},
normalImpl: cpuSelect,
codegenImpl: (_ctx, [f, t, cond]) => stitch`select(${f}, ${t}, ${cond})`,
codegenImpl: (ctx, [f, t, cond]) => {
const result = stitch`select(${f}, ${t}, ${cond})`;
if (
!validSelectBranchTypes.includes(f.dataType as AnyWgslData) ||
!validSelectBranchTypes.includes(t.dataType as AnyWgslData)
) {
throw new Error(
`'${result}' is invalid, std.select requires both branches to be either scalars or vectors.`,
);
}
if (f.dataType !== t.dataType) {
const fStr = ctx.resolve(f.dataType);
const tStr = ctx.resolve(t.dataType);
throw new Error(
`'${result}' is invalid, std.select requires both branches to be the same type, got [${fStr.value}, ${tStr.value}].`,
);
}
return result;
},
});
12 changes: 10 additions & 2 deletions packages/typegpu/src/tgsl/accessIndex.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { stitch } from '../core/resolve/stitch.ts';
import { isDisarray, MatrixColumnsAccess } from '../data/dataTypes.ts';
import { derefSnippet } from '../data/ref.ts';
import { type Origin, snip, type Snippet } from '../data/snippet.ts';
import { snip } from '../data/snippet.ts';
import type { Origin, Snippet } from '../data/snippet.ts';
import { vec2f, vec3f, vec4f } from '../data/vector.ts';
import { type BaseData, isPtr, isVec, isWgslArray, isWgslStruct } from '../data/wgslTypes.ts';
import { isKnownAtComptime } from '../types.ts';
Expand Down Expand Up @@ -51,6 +52,7 @@ export function accessIndex(target: Snippet, indexArg: Snippet | number): Snippe
: stitch`${target}[${index}]`,
elementType,
/* origin */ origin,
target.possibleSideEffects || index.possibleSideEffects,
);
}

Expand All @@ -63,6 +65,7 @@ export function accessIndex(target: Snippet, indexArg: Snippet | number): Snippe
: stitch`${target}[${index}]`,
target.dataType.primitive,
/* origin */ target.origin,
target.possibleSideEffects || index.possibleSideEffects,
);
}

Expand All @@ -79,7 +82,12 @@ export function accessIndex(target: Snippet, indexArg: Snippet | number): Snippe
(target.value.matrix.dataType as BaseData).type as keyof typeof indexableTypeToResult
];

return snip(stitch`${target.value.matrix}[${index}]`, propType, /* origin */ target.origin);
return snip(
stitch`${target.value.matrix}[${index}]`,
propType,
/* origin */ target.origin,
target.possibleSideEffects || index.possibleSideEffects,
);
}

// matrix
Expand Down
Loading
Loading