Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions packages/typegpu/src/core/function/autoIO.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ export class AutoFragmentFn implements SelfResolvable {
setName(impl, 'fragmentFn');
}
this.#core = createFnCore(impl, 'fragment');
this.autoIn = new AutoStruct({ ...builtinFragmentIn, ...varyings }, undefined, locations);
this.autoIn = new AutoStruct({ ...builtinFragmentIn, ...varyings }, undefined, locations, {
autoInterpolateIntegerVaryings: true,
});
setName(this.autoIn, 'FragmentIn');
this.autoOut = new AutoStruct(builtinFragmentOut, vec4f);
setName(this.autoOut, 'FragmentOut');
Expand Down Expand Up @@ -134,7 +136,9 @@ export class AutoVertexFn implements SelfResolvable {
this.#core = createFnCore(impl, 'vertex');
this.autoIn = new AutoStruct({ ...builtinVertexIn, ...attribs }, undefined, locations);
setName(this.autoIn, 'VertexIn');
this.autoOut = new AutoStruct(builtinVertexOut, undefined);
this.autoOut = new AutoStruct(builtinVertexOut, undefined, undefined, {
autoInterpolateIntegerVaryings: true,
});
setName(this.autoOut, 'VertexOut');
}

Expand Down
55 changes: 48 additions & 7 deletions packages/typegpu/src/core/function/ioSchema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@ import {
type Decorate,
type HasCustomLocation,
type IsBuiltin,
interpolate,
location,
} from '../../data/attributes.ts';
import { isBuiltin } from '../../data/attributes.ts';
import { getCustomLocation, isData } from '../../data/dataTypes.ts';
import { getCustomLocation, isData, undecorate } from '../../data/dataTypes.ts';
import { INTERNAL_createStruct } from '../../data/struct.ts';
import { type BaseData, isVoid, type Location, type WgslStruct } from '../../data/wgslTypes.ts';
import {
type BaseData,
type FlatInterpolatableData,
isInterpolateAttrib,
isVoid,
type Location,
type WgslStruct,
} from '../../data/wgslTypes.ts';
import type { SeparatedEntryArgs } from './fnTypes.ts';

export type WithLocations<T extends Record<string, BaseData>> = {
Expand All @@ -28,9 +36,39 @@ export type IOLayoutToSchema<T> = T extends BaseData
? void
: never;

const integerVaryingTypes = new Set([
'i32',
'u32',
'vec2i',
'vec2u',
'vec3i',
'vec3u',
'vec4i',
'vec4u',
]);

export type IoSchemaOptions = {
readonly autoInterpolateIntegerVaryings?: boolean;
};

function hasInterpolation(data: BaseData) {
return (data as { attribs?: unknown[] }).attribs?.some(isInterpolateAttrib) ?? false;
}

function maybeInterpolateIntegerVarying(data: BaseData, options: IoSchemaOptions) {
if (!options.autoInterpolateIntegerVaryings || hasInterpolation(data)) {
return data;
}

return integerVaryingTypes.has(undecorate(data).type)
? interpolate('flat', data as FlatInterpolatableData)
: data;
}

export function withLocations<T extends BaseData>(
members: Record<string, T> | undefined,
locations: Record<string, number> = {},
options: IoSchemaOptions = {},
): Record<string, BaseData> {
let nextLocation = 0;
const usedCustomLocations = new Set<number>();
Expand All @@ -54,28 +92,30 @@ export function withLocations<T extends BaseData>(
// skipping builtins
return [key, member];
}
const memberWithInterpolation = maybeInterpolateIntegerVarying(member, options);

if (getCustomLocation(member) !== undefined) {
// this member is already marked
return [key, member];
return [key, memberWithInterpolation];
}

if (locations[key]) {
// location has been determined by a previous procedure
return [key, location(locations[key], member)];
return [key, location(locations[key], memberWithInterpolation)];
}

while (usedCustomLocations.has(nextLocation)) {
nextLocation++;
}
return [key, location(nextLocation++, member)];
return [key, location(nextLocation++, memberWithInterpolation)];
}),
);
}

export function separateBuiltins(
schema: Record<string, BaseData>,
locations: Record<string, number> = {},
options: IoSchemaOptions = {},
): SeparatedEntryArgs {
const positionalArgs: SeparatedEntryArgs['positionalArgs'] = [];
const dataFields: Record<string, BaseData> = {};
Expand All @@ -90,7 +130,7 @@ export function separateBuiltins(

const dataSchema =
Object.keys(dataFields).length > 0
? INTERNAL_createStruct(withLocations(dataFields, locations), /* isAbstruct */ false)
? INTERNAL_createStruct(withLocations(dataFields, locations, options), /* isAbstruct */ false)
: undefined;

return { dataSchema, positionalArgs };
Expand All @@ -105,6 +145,7 @@ export function separateAllAsPositional(schema: Record<string, BaseData>): Separ
export function createIoSchema<T extends BaseData | Record<string, BaseData>>(
layout: T,
locations: Record<string, number> = {},
options: IoSchemaOptions = {},
) {
return (
isData(layout)
Expand All @@ -116,7 +157,7 @@ export function createIoSchema<T extends BaseData | Record<string, BaseData>>(
? layout
: location(0, layout)
: INTERNAL_createStruct(
withLocations(layout as Record<string, BaseData>, locations),
withLocations(layout as Record<string, BaseData>, locations, options),
/* isAbstruct */ false,
Comment on lines +160 to 161
)
) as IOLayoutToSchema<T>;
Expand Down
4 changes: 3 additions & 1 deletion packages/typegpu/src/core/function/tgpuFragmentFn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ function createFragmentFn(
},

[$resolve](ctx: ResolutionCtx): ResolvedSnippet {
const entryInput = separateBuiltins(shell.in ?? {}, ctx.varyingLocations ?? {});
const entryInput = separateBuiltins(shell.in ?? {}, ctx.varyingLocations ?? {}, {
autoInterpolateIntegerVaryings: true,
});

if (entryInput.dataSchema && isNamable(entryInput.dataSchema)) {
entryInput.dataSchema.$name(`${getName(this) ?? ''}_Input`);
Expand Down
6 changes: 3 additions & 3 deletions packages/typegpu/src/core/function/tgpuVertexFn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ function createVertexFn(
},

[$resolve](ctx: ResolutionCtx): ResolvedSnippet {
const outputWithLocation = createIoSchema(shell.out, ctx.varyingLocations).$name(
`${getName(this) ?? ''}_Output`,
);
const outputWithLocation = createIoSchema(shell.out, ctx.varyingLocations, {
autoInterpolateIntegerVaryings: true,
}).$name(`${getName(this) ?? ''}_Output`);

if (typeof implementation === 'string') {
core.applyExternals({ Out: outputWithLocation });
Expand Down
6 changes: 5 additions & 1 deletion packages/typegpu/src/data/autoStruct.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { createIoSchema } from '../core/function/ioSchema.ts';
import { createIoSchema, type IoSchemaOptions } from '../core/function/ioSchema.ts';
import { validateProp } from '../nameUtils.ts';
import { getName, setName } from '../shared/meta.ts';
import { $internal, $repr, $resolve } from '../shared/symbols.ts';
Expand Down Expand Up @@ -34,17 +34,20 @@ export class AutoStruct implements BaseData, SelfResolvable {
#locations: Record<string, number> | undefined;
#cachedStruct: WgslStruct | undefined;
#typeForExtraProps: BaseData | undefined;
#options: IoSchemaOptions;

constructor(
validProps: Record<string, BaseData>,
typeForExtraProps: BaseData | undefined,
locations?: Record<string, number>,
options: IoSchemaOptions = {},
) {
this.#validProps = validProps;
this.#typeForExtraProps = typeForExtraProps;
this.#allocated = {};
this.#locations = locations;
this.#usedWgslKeys = new Set();
this.#options = options;
}

/**
Expand Down Expand Up @@ -97,6 +100,7 @@ export class AutoStruct implements BaseData, SelfResolvable {
}),
),
this.#locations,
this.#options,
);
const ownName = getName(this);
// Passing the given name forward
Expand Down
29 changes: 29 additions & 0 deletions packages/typegpu/tests/entryFnHeaderGen.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,33 @@ describe('autogenerating wgsl headers for tgpu entry functions with raw string W
}"
`);
});

it('marks integer vertex output varyings as flat', () => {
const vertex = tgpu.vertexFn({
out: {
pos: d.builtin.position,
flag: d.u32,
},
}) /* wgsl */ `{ return Out(vec4f(), 1u); }`;

const fragment = tgpu.fragmentFn({
in: { flag: d.u32 },
out: d.vec4f,
}) /* wgsl */ `{ return vec4f(f32(in.flag)); }`;

expect(tgpu.resolve([vertex, fragment])).toMatchInlineSnapshot(`
"struct vertex_Output {
@builtin(position) pos: vec4f,
@location(0) @interpolate(flat) flag: u32,
}

@vertex fn vertex() -> vertex_Output { return vertex_Output(vec4f(), 1u); }

struct fragment_Input {
@location(0) @interpolate(flat) flag: u32,
}

@fragment fn fragment(in: fragment_Input) -> @location(0) vec4f { return vec4f(f32(in.flag)); }"
`);
});
});
10 changes: 5 additions & 5 deletions packages/typegpu/tests/renderPipeline.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ describe('root.withVertex(...).withFragment(...)', () => {
@location(1) bar: vec3f,
@location(0) baz: vec3f,
@location(5) baz2: f32,
@location(3) baz3: u32,
@location(3) @interpolate(flat) baz3: u32,
@builtin(position) pos: vec4f,
}

Expand Down Expand Up @@ -320,13 +320,13 @@ describe('root.withVertex(...).withFragment(...)', () => {
@builtin(position) position: vec4f,
@location(0) baz: vec3f,
@location(5) baz2: f32,
@location(3) baz3: u32,
@location(3) @interpolate(flat) baz3: u32,
}

@vertex fn vertexMain() -> vertexMain_Output { return vertexMain_Output(); }

struct fragmentMain_Input {
@location(3) baz3: u32,
@location(3) @interpolate(flat) baz3: u32,
@location(1) bar: vec3f,
@location(2) foo: vec3f,
@location(5) baz2: f32,
Expand Down Expand Up @@ -1268,15 +1268,15 @@ describe('root.createRenderPipeline', () => {
expect(tgpu.resolve([pipeline])).toMatchInlineSnapshot(`
"struct VertexOut {
@builtin(position) position: vec4f,
@location(0) prop: i32,
@location(0) @interpolate(flat) prop: i32,
}

@vertex fn vertex() -> VertexOut {
return VertexOut(vec4f(), 0i);
}

struct FragmentIn {
@location(0) prop: i32,
@location(0) @interpolate(flat) prop: i32,
}

@fragment fn fragment(_arg_0: FragmentIn) -> @location(0) vec4f {
Expand Down