diff --git a/packages/typegpu/src/core/function/autoIO.ts b/packages/typegpu/src/core/function/autoIO.ts index b35a092b1a..54e222f904 100644 --- a/packages/typegpu/src/core/function/autoIO.ts +++ b/packages/typegpu/src/core/function/autoIO.ts @@ -86,7 +86,9 @@ export class AutoFragmentFn implements SelfResolvable { setName(impl, 'fragmentFn'); } this.#core = createFnCore(impl, 'fragment'); - this.autoIn = new AutoStruct({ ...builtinFragmentIn, ...varyings }, undefined, locations); + this.autoIn = new AutoStruct({ ...builtinFragmentIn, ...varyings }, undefined, locations, { + autoInterpolateIntegerVaryings: true, + }); setName(this.autoIn, 'FragmentIn'); this.autoOut = new AutoStruct(builtinFragmentOut, vec4f); setName(this.autoOut, 'FragmentOut'); @@ -134,7 +136,9 @@ export class AutoVertexFn implements SelfResolvable { this.#core = createFnCore(impl, 'vertex'); this.autoIn = new AutoStruct({ ...builtinVertexIn, ...attribs }, undefined, locations); setName(this.autoIn, 'VertexIn'); - this.autoOut = new AutoStruct(builtinVertexOut, undefined); + this.autoOut = new AutoStruct(builtinVertexOut, undefined, undefined, { + autoInterpolateIntegerVaryings: true, + }); setName(this.autoOut, 'VertexOut'); } diff --git a/packages/typegpu/src/core/function/ioSchema.ts b/packages/typegpu/src/core/function/ioSchema.ts index 9d3fe80726..f7e94f4ddc 100644 --- a/packages/typegpu/src/core/function/ioSchema.ts +++ b/packages/typegpu/src/core/function/ioSchema.ts @@ -2,12 +2,20 @@ import { type Decorate, type HasCustomLocation, type IsBuiltin, + interpolate, location, } from '../../data/attributes.ts'; import { isBuiltin } from '../../data/attributes.ts'; -import { getCustomLocation, isData } from '../../data/dataTypes.ts'; +import { getCustomLocation, isData, undecorate } from '../../data/dataTypes.ts'; import { INTERNAL_createStruct } from '../../data/struct.ts'; -import { type BaseData, isVoid, type Location, type WgslStruct } from '../../data/wgslTypes.ts'; +import { + type BaseData, + type FlatInterpolatableData, + isInterpolateAttrib, + isVoid, + type Location, + type WgslStruct, +} from '../../data/wgslTypes.ts'; import type { SeparatedEntryArgs } from './fnTypes.ts'; export type WithLocations> = { @@ -28,9 +36,39 @@ export type IOLayoutToSchema = T extends BaseData ? void : never; +const integerVaryingTypes = new Set([ + 'i32', + 'u32', + 'vec2i', + 'vec2u', + 'vec3i', + 'vec3u', + 'vec4i', + 'vec4u', +]); + +export type IoSchemaOptions = { + readonly autoInterpolateIntegerVaryings?: boolean; +}; + +function hasInterpolation(data: BaseData) { + return (data as { attribs?: unknown[] }).attribs?.some(isInterpolateAttrib) ?? false; +} + +function maybeInterpolateIntegerVarying(data: BaseData, options: IoSchemaOptions) { + if (!options.autoInterpolateIntegerVaryings || hasInterpolation(data)) { + return data; + } + + return integerVaryingTypes.has(undecorate(data).type) + ? interpolate('flat', data as FlatInterpolatableData) + : data; +} + export function withLocations( members: Record | undefined, locations: Record = {}, + options: IoSchemaOptions = {}, ): Record { let nextLocation = 0; const usedCustomLocations = new Set(); @@ -54,21 +92,22 @@ export function withLocations( // skipping builtins return [key, member]; } + const memberWithInterpolation = maybeInterpolateIntegerVarying(member, options); if (getCustomLocation(member) !== undefined) { // this member is already marked - return [key, member]; + return [key, memberWithInterpolation]; } if (locations[key]) { // location has been determined by a previous procedure - return [key, location(locations[key], member)]; + return [key, location(locations[key], memberWithInterpolation)]; } while (usedCustomLocations.has(nextLocation)) { nextLocation++; } - return [key, location(nextLocation++, member)]; + return [key, location(nextLocation++, memberWithInterpolation)]; }), ); } @@ -76,6 +115,7 @@ export function withLocations( export function separateBuiltins( schema: Record, locations: Record = {}, + options: IoSchemaOptions = {}, ): SeparatedEntryArgs { const positionalArgs: SeparatedEntryArgs['positionalArgs'] = []; const dataFields: Record = {}; @@ -90,7 +130,7 @@ export function separateBuiltins( const dataSchema = Object.keys(dataFields).length > 0 - ? INTERNAL_createStruct(withLocations(dataFields, locations), /* isAbstruct */ false) + ? INTERNAL_createStruct(withLocations(dataFields, locations, options), /* isAbstruct */ false) : undefined; return { dataSchema, positionalArgs }; @@ -105,6 +145,7 @@ export function separateAllAsPositional(schema: Record): Separ export function createIoSchema>( layout: T, locations: Record = {}, + options: IoSchemaOptions = {}, ) { return ( isData(layout) @@ -116,7 +157,7 @@ export function createIoSchema>( ? layout : location(0, layout) : INTERNAL_createStruct( - withLocations(layout as Record, locations), + withLocations(layout as Record, locations, options), /* isAbstruct */ false, ) ) as IOLayoutToSchema; diff --git a/packages/typegpu/src/core/function/tgpuFragmentFn.ts b/packages/typegpu/src/core/function/tgpuFragmentFn.ts index cb3b4c990f..6ff1c55b2a 100644 --- a/packages/typegpu/src/core/function/tgpuFragmentFn.ts +++ b/packages/typegpu/src/core/function/tgpuFragmentFn.ts @@ -211,7 +211,9 @@ function createFragmentFn( }, [$resolve](ctx: ResolutionCtx): ResolvedSnippet { - const entryInput = separateBuiltins(shell.in ?? {}, ctx.varyingLocations ?? {}); + const entryInput = separateBuiltins(shell.in ?? {}, ctx.varyingLocations ?? {}, { + autoInterpolateIntegerVaryings: true, + }); if (entryInput.dataSchema && isNamable(entryInput.dataSchema)) { entryInput.dataSchema.$name(`${getName(this) ?? ''}_Input`); diff --git a/packages/typegpu/src/core/function/tgpuVertexFn.ts b/packages/typegpu/src/core/function/tgpuVertexFn.ts index 8c5b230891..0a5d286a2b 100644 --- a/packages/typegpu/src/core/function/tgpuVertexFn.ts +++ b/packages/typegpu/src/core/function/tgpuVertexFn.ts @@ -177,9 +177,9 @@ function createVertexFn( }, [$resolve](ctx: ResolutionCtx): ResolvedSnippet { - const outputWithLocation = createIoSchema(shell.out, ctx.varyingLocations).$name( - `${getName(this) ?? ''}_Output`, - ); + const outputWithLocation = createIoSchema(shell.out, ctx.varyingLocations, { + autoInterpolateIntegerVaryings: true, + }).$name(`${getName(this) ?? ''}_Output`); if (typeof implementation === 'string') { core.applyExternals({ Out: outputWithLocation }); diff --git a/packages/typegpu/src/data/autoStruct.ts b/packages/typegpu/src/data/autoStruct.ts index 4fc745cccb..94195248a8 100644 --- a/packages/typegpu/src/data/autoStruct.ts +++ b/packages/typegpu/src/data/autoStruct.ts @@ -1,4 +1,4 @@ -import { createIoSchema } from '../core/function/ioSchema.ts'; +import { createIoSchema, type IoSchemaOptions } from '../core/function/ioSchema.ts'; import { validateProp } from '../nameUtils.ts'; import { getName, setName } from '../shared/meta.ts'; import { $internal, $repr, $resolve } from '../shared/symbols.ts'; @@ -34,17 +34,20 @@ export class AutoStruct implements BaseData, SelfResolvable { #locations: Record | undefined; #cachedStruct: WgslStruct | undefined; #typeForExtraProps: BaseData | undefined; + #options: IoSchemaOptions; constructor( validProps: Record, typeForExtraProps: BaseData | undefined, locations?: Record, + options: IoSchemaOptions = {}, ) { this.#validProps = validProps; this.#typeForExtraProps = typeForExtraProps; this.#allocated = {}; this.#locations = locations; this.#usedWgslKeys = new Set(); + this.#options = options; } /** @@ -97,6 +100,7 @@ export class AutoStruct implements BaseData, SelfResolvable { }), ), this.#locations, + this.#options, ); const ownName = getName(this); // Passing the given name forward diff --git a/packages/typegpu/tests/entryFnHeaderGen.test.ts b/packages/typegpu/tests/entryFnHeaderGen.test.ts index a0b9abfa69..ed47c60ed3 100644 --- a/packages/typegpu/tests/entryFnHeaderGen.test.ts +++ b/packages/typegpu/tests/entryFnHeaderGen.test.ts @@ -123,4 +123,33 @@ describe('autogenerating wgsl headers for tgpu entry functions with raw string W }" `); }); + + it('marks integer vertex output varyings as flat', () => { + const vertex = tgpu.vertexFn({ + out: { + pos: d.builtin.position, + flag: d.u32, + }, + }) /* wgsl */ `{ return Out(vec4f(), 1u); }`; + + const fragment = tgpu.fragmentFn({ + in: { flag: d.u32 }, + out: d.vec4f, + }) /* wgsl */ `{ return vec4f(f32(in.flag)); }`; + + expect(tgpu.resolve([vertex, fragment])).toMatchInlineSnapshot(` + "struct vertex_Output { + @builtin(position) pos: vec4f, + @location(0) @interpolate(flat) flag: u32, + } + + @vertex fn vertex() -> vertex_Output { return vertex_Output(vec4f(), 1u); } + + struct fragment_Input { + @location(0) @interpolate(flat) flag: u32, + } + + @fragment fn fragment(in: fragment_Input) -> @location(0) vec4f { return vec4f(f32(in.flag)); }" + `); + }); }); diff --git a/packages/typegpu/tests/renderPipeline.test.ts b/packages/typegpu/tests/renderPipeline.test.ts index 8c10b79f5b..a3f5c7dad4 100644 --- a/packages/typegpu/tests/renderPipeline.test.ts +++ b/packages/typegpu/tests/renderPipeline.test.ts @@ -269,7 +269,7 @@ describe('root.withVertex(...).withFragment(...)', () => { @location(1) bar: vec3f, @location(0) baz: vec3f, @location(5) baz2: f32, - @location(3) baz3: u32, + @location(3) @interpolate(flat) baz3: u32, @builtin(position) pos: vec4f, } @@ -320,13 +320,13 @@ describe('root.withVertex(...).withFragment(...)', () => { @builtin(position) position: vec4f, @location(0) baz: vec3f, @location(5) baz2: f32, - @location(3) baz3: u32, + @location(3) @interpolate(flat) baz3: u32, } @vertex fn vertexMain() -> vertexMain_Output { return vertexMain_Output(); } struct fragmentMain_Input { - @location(3) baz3: u32, + @location(3) @interpolate(flat) baz3: u32, @location(1) bar: vec3f, @location(2) foo: vec3f, @location(5) baz2: f32, @@ -1268,7 +1268,7 @@ describe('root.createRenderPipeline', () => { expect(tgpu.resolve([pipeline])).toMatchInlineSnapshot(` "struct VertexOut { @builtin(position) position: vec4f, - @location(0) prop: i32, + @location(0) @interpolate(flat) prop: i32, } @vertex fn vertex() -> VertexOut { @@ -1276,7 +1276,7 @@ describe('root.createRenderPipeline', () => { } struct FragmentIn { - @location(0) prop: i32, + @location(0) @interpolate(flat) prop: i32, } @fragment fn fragment(_arg_0: FragmentIn) -> @location(0) vec4f {