Skip to content

Commit ea5b2db

Browse files
authored
feat: More convenient TgpuVertexFn.AutoIn and TgpuVertexFn.AutoOut types (#2282)
1 parent 35a38e7 commit ea5b2db

File tree

10 files changed

+81
-66
lines changed

10 files changed

+81
-66
lines changed

apps/typegpu-docs/src/examples/algorithms/jump-flood-voronoi/index.ts

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import { defineControls } from '../../common/defineControls.ts';
66
const root = await tgpu.init();
77

88
const canvas = document.querySelector('canvas') as HTMLCanvasElement;
9-
const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
109

1110
const context = root.configureContext({ canvas });
1211

@@ -176,17 +175,12 @@ const jumpFlood = root.createGuardedComputePipeline((x, y) => {
176175
std.textureStore(pingPongLayout.$.writeView, d.vec2i(x, y), 1, d.vec4f(bestSample.coord, 0, 0));
177176
});
178177

179-
const voronoiFrag = tgpu.fragmentFn({
180-
in: { uv: d.vec2f },
181-
out: d.vec4f,
182-
})(({ uv }) =>
183-
std.textureSample(colorSampleLayout.$.floodTexture, colorSampleLayout.$.sampler, uv),
184-
);
185-
186178
const voronoiPipeline = root.createRenderPipeline({
187179
vertex: common.fullScreenTriangle,
188-
fragment: voronoiFrag,
189-
targets: { format: presentationFormat },
180+
fragment: ({ uv }) => {
181+
'use gpu';
182+
return std.textureSample(colorSampleLayout.$.floodTexture, colorSampleLayout.$.sampler, uv);
183+
},
190184
});
191185

192186
const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));

apps/typegpu-docs/src/examples/image-processing/background-segmentation/shaders.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import tgpu, { d, std } from 'typegpu';
1+
import tgpu, { d, std, type TgpuFragmentFn } from 'typegpu';
22
import { MODEL_HEIGHT, MODEL_WIDTH } from './model.ts';
33
import {
44
blockDim,
@@ -94,9 +94,8 @@ export const computeFn = tgpu.computeFn({
9494
}
9595
});
9696

97-
export const fragmentFn = (input: { uv: d.v2f }) => {
97+
export const fragmentFn = ({ uv }: TgpuFragmentFn.AutoIn<{ uv: d.v2f }>) => {
9898
'use gpu';
99-
const uv = input.uv;
10099
const originalColor = std.textureSampleBaseClampToEdge(
101100
drawWithMaskLayout.$.inputTexture,
102101
drawWithMaskLayout.$.sampler,

apps/typegpu-docs/src/examples/rendering/box-raytracing/index.ts

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { linearToSrgb, srgbToLinear } from '@typegpu/color';
2-
import tgpu, { d } from 'typegpu';
3-
import { add, discard, div, max, min, mul, normalize, pow, sub } from 'typegpu/std';
2+
import tgpu, { d, type TgpuFragmentFn, type TgpuVertexFn } from 'typegpu';
3+
import { discard, max, min, mul, normalize, pow, sub } from 'typegpu/std';
44
import { mat4 } from 'wgpu-matrix';
55
import { defineControls } from '../../common/defineControls.ts';
66

@@ -133,41 +133,39 @@ const getBoxIntersection = tgpu
133133
}`)
134134
.$uses({ IntersectionStruct });
135135

136-
const Varying = {
137-
rayWorldOrigin: d.vec3f,
138-
};
139-
140-
const mainVertex = tgpu.vertexFn({
141-
in: { vertexIndex: d.builtin.vertexIndex },
142-
out: { pos: d.builtin.position, ...Varying },
143-
})((input) => {
136+
const mainVertex = ({ $vertexIndex: vid }: TgpuVertexFn.AutoIn<{}>) => {
137+
'use gpu';
144138
const pos = [d.vec2f(-1, -1), d.vec2f(3, -1), d.vec2f(-1, 3)];
145139

146-
const rayWorldOrigin = mul(uniforms.$.invViewMatrix, d.vec4f(0, 0, 0, 1)).xyz;
140+
const rayWorldOrigin = (uniforms.$.invViewMatrix * d.vec4f(0, 0, 0, 1)).xyz;
147141

148-
return { pos: d.vec4f(pos[input.vertexIndex], 0.0, 1.0), rayWorldOrigin };
149-
});
142+
return {
143+
$position: d.vec4f(pos[vid], 0, 1),
144+
rayWorldOrigin,
145+
} satisfies TgpuVertexFn.AutoOut;
146+
};
150147

151-
const fragmentFunction = tgpu.fragmentFn({
152-
in: { position: d.builtin.position, ...Varying },
153-
out: d.vec4f,
154-
})((input) => {
155-
const boxSize3 = d.vec3f(d.f32(uniforms.$.boxSize));
156-
const halfBoxSize3 = mul(0.5, boxSize3);
157-
const halfCanvasDims = mul(0.5, uniforms.$.canvasDims);
148+
const fragmentFunction = ({
149+
$position,
150+
rayWorldOrigin,
151+
}: TgpuFragmentFn.AutoIn<{ rayWorldOrigin: d.v3f }>) => {
152+
'use gpu';
153+
const boxSize3 = d.vec3f(uniforms.$.boxSize);
154+
const halfBoxSize3 = 0.5 * boxSize3;
155+
const halfCanvasDims = 0.5 * uniforms.$.canvasDims;
158156

159157
const minDim = min(uniforms.$.canvasDims.x, uniforms.$.canvasDims.y);
160-
const viewCoords = div(sub(input.position.xy, halfCanvasDims), minDim);
158+
const viewCoords = ($position.xy - halfCanvasDims) / minDim;
161159

162160
const ray = Ray({
163-
origin: input.rayWorldOrigin,
164-
direction: mul(uniforms.$.invViewMatrix, d.vec4f(normalize(d.vec3f(viewCoords, 1)), 0)).xyz,
161+
origin: rayWorldOrigin,
162+
direction: (uniforms.$.invViewMatrix * d.vec4f(normalize(d.vec3f(viewCoords, 1)), 0)).xyz,
165163
});
166164

167165
const bigBoxIntersection = getBoxIntersection(
168166
AxisAlignedBounds({
169-
min: mul(-1, halfBoxSize3),
170-
max: add(cubeSize, halfBoxSize3),
167+
min: -1 * halfBoxSize3,
168+
max: cubeSize + halfBoxSize3,
171169
}),
172170
ray,
173171
);
@@ -188,12 +186,12 @@ const fragmentFunction = tgpu.fragmentFn({
188186
continue;
189187
}
190188

191-
const ijkScaled = d.vec3f(d.f32(i), d.f32(j), d.f32(k));
189+
const ijkScaled = d.vec3f(i, j, k);
192190

193191
const intersection = getBoxIntersection(
194192
AxisAlignedBounds({
195-
min: sub(ijkScaled, halfBoxSize3),
196-
max: add(ijkScaled, halfBoxSize3),
193+
min: ijkScaled - halfBoxSize3,
194+
max: ijkScaled + halfBoxSize3,
197195
}),
198196
ray,
199197
);
@@ -202,25 +200,25 @@ const fragmentFunction = tgpu.fragmentFn({
202200
const boxDensity =
203201
max(0, intersection.tMax - intersection.tMin) * pow(uniforms.$.materialDensity, 2);
204202
density += boxDensity;
205-
invColor = add(invColor, mul(boxDensity, div(d.vec3f(1), boxMatrix.$[i][j][k].albedo)));
203+
invColor += boxDensity * (1 / boxMatrix.$[i][j][k].albedo);
206204
intersectionFound = true;
207205
}
208206
}
209207
}
210208
}
211209

212-
const linear = div(d.vec3f(1), invColor);
210+
const linear = 1 / invColor;
213211
const srgb = linearToSrgb(linear);
214212
const gamma = 2.2;
215-
const corrected = pow(srgb, d.vec3f(1.0 / gamma));
213+
const corrected = pow(srgb, d.vec3f(1 / gamma));
216214

217215
if (intersectionFound) {
218-
return mul(min(density, 1), d.vec4f(min(corrected, d.vec3f(1)), 1));
216+
return min(density, 1) * d.vec4f(min(corrected, d.vec3f(1)), 1);
219217
}
220218

221219
discard();
222220
return d.vec4f();
223-
});
221+
};
224222

225223
// pipeline
226224

apps/typegpu-docs/src/examples/tests/log-test/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import tgpu, { d, std, type AutoVertexIn } from 'typegpu';
1+
import tgpu, { d, std, type TgpuVertexFn } from 'typegpu';
22
import { defineControls } from '../../common/defineControls.ts';
33

44
const root = await tgpu.init({
@@ -256,7 +256,7 @@ export const controls = defineControls({
256256
console.log(n);
257257
};
258258

259-
const vs = ({ $vertexIndex }: AutoVertexIn<{}>) => {
259+
const vs = ({ $vertexIndex }: TgpuVertexFn.AutoIn<{}>) => {
260260
'use gpu';
261261
const positions = [d.vec2f(0, 0.5), d.vec2f(-0.5, -0.5), d.vec2f(0.5, -0.5)];
262262
myLog(6);

apps/typegpu-docs/tests/individual-example-tests/box-raytracing.test.ts

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,19 @@ describe('box raytracing example', () => {
2929
3030
@group(0) @binding(0) var<uniform> uniforms: Uniforms;
3131
32-
struct mainVertex_Output {
33-
@builtin(position) pos: vec4f,
32+
struct VertexOut {
33+
@builtin(position) position: vec4f,
3434
@location(0) rayWorldOrigin: vec3f,
3535
}
3636
37-
@vertex fn mainVertex(@builtin(vertex_index) _arg_vertexIndex: u32) -> mainVertex_Output {
38-
var pos = array<vec2f, 3>(vec2f(-1), vec2f(3, -1), vec2f(-1, 3));
39-
var rayWorldOrigin = (uniforms.invViewMatrix * vec4f(0, 0, 0, 1)).xyz;
40-
return mainVertex_Output(vec4f(pos[_arg_vertexIndex], 0f, 1f), rayWorldOrigin);
37+
struct VertexIn {
38+
@builtin(vertex_index) vertexIndex: u32,
4139
}
4240
43-
struct fragmentFunction_Input {
44-
@location(0) rayWorldOrigin: vec3f,
41+
@vertex fn mainVertex(_arg_0: VertexIn) -> VertexOut {
42+
var pos = array<vec2f, 3>(vec2f(-1), vec2f(3, -1), vec2f(-1, 3));
43+
var rayWorldOrigin = (uniforms.invViewMatrix * vec4f(0, 0, 0, 1)).xyz;
44+
return VertexOut(vec4f(pos[_arg_0.vertexIndex], 0f, 1f), rayWorldOrigin);
4545
}
4646
4747
struct Ray {
@@ -130,12 +130,17 @@ describe('box raytracing example', () => {
130130
return select((12.92f * linear), ((1.055f * pow(linear, vec3f(0.4166666567325592))) - vec3f(0.054999999701976776)), (linear > vec3f(0.0031308000907301903)));
131131
}
132132
133-
@fragment fn fragmentFunction(_arg_0: fragmentFunction_Input, @builtin(position) _arg_position: vec4f) -> @location(0) vec4f {
133+
struct FragmentIn {
134+
@builtin(position) position: vec4f,
135+
@location(0) rayWorldOrigin: vec3f,
136+
}
137+
138+
@fragment fn fragmentFunction(_arg_0: FragmentIn) -> @location(0) vec4f {
134139
var boxSize3 = vec3f(uniforms.boxSize);
135140
var halfBoxSize3 = (0.5f * boxSize3);
136141
var halfCanvasDims = (0.5f * uniforms.canvasDims);
137142
let minDim = min(uniforms.canvasDims.x, uniforms.canvasDims.y);
138-
var viewCoords = ((_arg_position.xy - halfCanvasDims) / minDim);
143+
var viewCoords = ((_arg_0.position.xy - halfCanvasDims) / minDim);
139144
var ray = Ray(_arg_0.rayWorldOrigin, (uniforms.invViewMatrix * vec4f(normalize(vec3f(viewCoords, 1f)), 0f)).xyz);
140145
var bigBoxIntersection = getBoxIntersection(AxisAlignedBounds((-1f * halfBoxSize3), (vec3f(7) + halfBoxSize3)), ray);
141146
if (!bigBoxIntersection.intersects) {
@@ -156,13 +161,13 @@ describe('box raytracing example', () => {
156161
if (intersection.intersects) {
157162
let boxDensity = (max(0f, (intersection.tMax - intersection.tMin)) * pow(uniforms.materialDensity, 2f));
158163
density += boxDensity;
159-
invColor = (invColor + (boxDensity * (vec3f(1) / boxMatrix[i][j][k].albedo)));
164+
invColor += (boxDensity * (1f / boxMatrix[i][j][k].albedo));
160165
intersectionFound = true;
161166
}
162167
}
163168
}
164169
}
165-
var linear = (vec3f(1) / invColor);
170+
var linear = (1f / invColor);
166171
var srgb = linearToSrgb(linear);
167172
const gamma = 2.2;
168173
var corrected = pow(srgb, vec3f((1f / gamma)));

apps/typegpu-docs/tests/individual-example-tests/jump-flood-voronoi.test.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,15 @@ describe('jump flood (voronoi) example', () => {
8686
return fullScreenTriangle_Output(vec4f(pos[vertexIndex], 0, 1), uv[vertexIndex]);
8787
}
8888
89-
struct voronoiFrag_Input {
90-
@location(0) uv: vec2f,
91-
}
92-
9389
@group(0) @binding(0) var floodTexture: texture_2d<f32>;
9490
9591
@group(0) @binding(1) var sampler_1: sampler;
9692
97-
@fragment fn voronoiFrag(_arg_0: voronoiFrag_Input) -> @location(0) vec4f {
93+
struct FragmentIn {
94+
@location(0) uv: vec2f,
95+
}
96+
97+
@fragment fn fragment(_arg_0: FragmentIn) -> @location(0) vec4f {
9898
return textureSample(floodTexture, sampler_1, _arg_0.uv);
9999
}
100100

packages/typegpu/src/core/function/autoIO.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ import { builtin, type OmitBuiltins } from '../../builtin.ts';
22
import { AutoStruct } from '../../data/autoStruct.ts';
33
import type { ResolvedSnippet } from '../../data/snippet.ts';
44
import { vec4f } from '../../data/vector.ts';
5+
import type { FormatToWGSLType } from '../../data/vertexFormatData.ts';
56
import type { BaseData, v4f } from '../../data/wgslTypes.ts';
67
import { getName, setName } from '../../shared/meta.ts';
78
import type { InferGPU, InferGPURecord, InferRecord } from '../../shared/repr.ts';
89
import { $internal, $resolve } from '../../shared/symbols.ts';
10+
import type { Assume } from '../../shared/utilityTypes.ts';
11+
import type { TgpuVertexAttrib } from '../../shared/vertexFormat.ts';
912
import type { ResolutionCtx, SelfResolvable } from '../../types.ts';
1013
import { shaderStageSlot } from '../slot/internalSlots.ts';
1114
import { createFnCore, type FnCore } from './fnCore.ts';
@@ -20,6 +23,12 @@ const builtinVertexIn = {
2023

2124
export type AutoVertexIn<T extends AnyAutoCustoms> = T & InferRecord<typeof builtinVertexIn>;
2225

26+
export type _AutoVertexIn<T> = AutoVertexIn<{
27+
[Key in keyof T]: T[Key] extends TgpuVertexAttrib
28+
? InferGPU<FormatToWGSLType<T[Key]['format']>>
29+
: Assume<T[Key], InferGPU<BaseIOData>>;
30+
}>;
31+
2332
const builtinVertexOut = {
2433
$clipDistances: builtin.clipDistances,
2534
$position: builtin.position,

packages/typegpu/src/core/function/tgpuFragmentFn.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import type { Prettify } from '../../shared/utilityTypes.ts';
2222
import type { ResolutionCtx, SelfResolvable } from '../../types.ts';
2323
import { addReturnTypeToExternals } from '../resolve/externals.ts';
2424
import { shaderStageSlot } from '../slot/internalSlots.ts';
25+
import type { AnyAutoCustoms, AutoFragmentIn, AutoFragmentOut } from './autoIO.ts';
2526
import { createFnCore, type FnCore } from './fnCore.ts';
2627
import type { BaseIOData, Implementation, InferIO, IOLayout, IORecord } from './fnTypes.ts';
2728
import { createIoSchema, type IOLayoutToSchema, separateBuiltins } from './ioSchema.ts';
@@ -118,6 +119,8 @@ export declare namespace TgpuFragmentFn {
118119
// readable, and refactoring to use a builtin argument is too much hassle.
119120
type In = Record<string, BaseData>;
120121
type Out = Record<string, BaseData> | BaseData;
122+
type AutoIn<T extends AnyAutoCustoms> = AutoFragmentIn<T>;
123+
type AutoOut<T extends AnyAutoCustoms = AnyAutoCustoms> = AutoFragmentOut<T>;
121124
}
122125

123126
export function fragmentFn<FragmentOut extends FragmentOutConstrained>(options: {

packages/typegpu/src/core/function/tgpuVertexFn.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import { $getNameForward, $internal, $resolve } from '../../shared/symbols.ts';
77
import type { Prettify } from '../../shared/utilityTypes.ts';
88
import type { ResolutionCtx, SelfResolvable } from '../../types.ts';
99
import { shaderStageSlot } from '../slot/internalSlots.ts';
10+
import type { _AutoVertexIn, AnyAutoCustoms, AutoVertexOut } from './autoIO.ts';
1011
import { createFnCore, type FnCore } from './fnCore.ts';
1112
import type {
1213
BaseIOData,
@@ -83,6 +84,8 @@ export interface TgpuVertexFn<
8384
export declare namespace TgpuVertexFn {
8485
type In = BaseData | Record<string, BaseData>;
8586
type Out = Record<string, BaseData>;
87+
type AutoIn<T> = _AutoVertexIn<T>;
88+
type AutoOut<T extends AnyAutoCustoms = AnyAutoCustoms> = AutoVertexOut<T>;
8689
}
8790

8891
export function vertexFn<VertexOut extends VertexOutConstrained>(options: {

packages/typegpu/src/indexNamedExports.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,13 @@ export type { TgpuVertexFn, TgpuVertexFnShell } from './core/function/tgpuVertex
110110
export type { TgpuFragmentFn, TgpuFragmentFnShell } from './core/function/tgpuFragmentFn.ts';
111111
export type { TgpuComputeFn, TgpuComputeFnShell } from './core/function/tgpuComputeFn.ts';
112112
export type {
113+
/** @deprecated use TgpuFragmentFn.AutoIn */
113114
AutoFragmentIn,
115+
/** @deprecated use TgpuFragmentFn.AutoOut */
114116
AutoFragmentOut,
115-
AutoVertexIn,
117+
/** @deprecated use TgpuVertexFn.AutoIn */
118+
_AutoVertexIn as AutoVertexIn,
119+
/** @deprecated use TgpuVertexFn.AutoOut */
116120
AutoVertexOut,
117121
} from './core/function/autoIO.ts';
118122
export type { TgpuDeclare } from './core/declare/tgpuDeclare.ts';

0 commit comments

Comments
 (0)