Skip to content

Commit 2c81dfa

Browse files
authored
impr: withPerformanceCallback on guardedComputePipeline (#2424)
1 parent 3bd330a commit 2c81dfa

3 files changed

Lines changed: 84 additions & 12 deletions

File tree

packages/typegpu/src/core/root/init.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,30 @@ export class TgpuGuardedComputePipelineImpl<
152152
);
153153
}
154154

155+
withPerformanceCallback(
156+
callback: (start: bigint, end: bigint) => void | Promise<void>,
157+
): TgpuGuardedComputePipeline<TArgs> {
158+
return new TgpuGuardedComputePipelineImpl(
159+
this.#root,
160+
this.#pipeline.withPerformanceCallback(callback),
161+
this.#sizeUniform,
162+
this.#workgroupSize,
163+
);
164+
}
165+
166+
withTimestampWrites(options: {
167+
querySet: TgpuQuerySet<'timestamp'> | GPUQuerySet;
168+
beginningOfPassWriteIndex?: number;
169+
endOfPassWriteIndex?: number;
170+
}): TgpuGuardedComputePipeline<TArgs> {
171+
return new TgpuGuardedComputePipelineImpl(
172+
this.#root,
173+
this.#pipeline.withTimestampWrites(options),
174+
this.#sizeUniform,
175+
this.#workgroupSize,
176+
);
177+
}
178+
155179
dispatchThreads(...threads: TArgs): void {
156180
const sanitizedSize = toVec3(threads);
157181
const workgroupCount = ceil(vec3f(sanitizedSize).div(vec3f(this.#workgroupSize)));

packages/typegpu/src/core/root/rootTypes.ts

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import type { AnyComputeBuiltin, AnyFragmentInputBuiltin, OmitBuiltins } from '../../builtin.ts';
22
import type { TgpuQuerySet } from '../../core/querySet/querySet.ts';
33
import type { AnyData, Disarray, UndecorateRecord } from '../../data/dataTypes.ts';
4+
import type { InstanceToSchema } from '../../data/instanceToSchema.ts';
45
import type { WgslComparisonSamplerProps, WgslSamplerProps } from '../../data/sampler.ts';
56
import type {
67
AnyWgslData,
@@ -12,6 +13,7 @@ import type {
1213
Void,
1314
WgslArray,
1415
} from '../../data/wgslTypes.ts';
16+
import type { TgpuNamable } from '../../shared/meta.ts';
1517
import type {
1618
ExtractInvalidSchemaError,
1719
InferGPURecord,
@@ -33,7 +35,13 @@ import type { ShaderGenerator } from '../../tgsl/shaderGenerator.ts';
3335
import type { Unwrapper } from '../../unwrapper.ts';
3436
import type { TgpuBuffer, VertexFlag } from '../buffer/buffer.ts';
3537
import type { TgpuMutable, TgpuReadonly, TgpuUniform } from '../buffer/bufferShorthand.ts';
36-
import type { TgpuFixedComparisonSampler, TgpuFixedSampler } from '../sampler/sampler.ts';
38+
import type {
39+
AnyAutoCustoms,
40+
AutoFragmentIn,
41+
AutoFragmentOut,
42+
AutoVertexIn,
43+
AutoVertexOut,
44+
} from '../function/autoIO.ts';
3745
import type { IORecord } from '../function/fnTypes.ts';
3846
import type {
3947
FragmentInConstrained,
@@ -44,6 +52,7 @@ import type {
4452
import type { TgpuVertexFn } from '../function/tgpuVertexFn.ts';
4553
import type { TgpuComputePipeline } from '../pipeline/computePipeline.ts';
4654
import type { FragmentOutToTargets, TgpuRenderPipeline } from '../pipeline/renderPipeline.ts';
55+
import type { TgpuFixedComparisonSampler, TgpuFixedSampler } from '../sampler/sampler.ts';
4756
import type { Eventual, TgpuAccessor, TgpuMutableAccessor, TgpuSlot } from '../slot/slotTypes.ts';
4857
import type { TgpuTexture } from '../texture/texture.ts';
4958
import type {
@@ -52,15 +61,6 @@ import type {
5261
} from '../vertexLayout/vertexAttribute.ts';
5362
import type { TgpuVertexLayout } from '../vertexLayout/vertexLayout.ts';
5463
import type { TgpuComputeFn } from './../function/tgpuComputeFn.ts';
55-
import type { TgpuNamable } from '../../shared/meta.ts';
56-
import type {
57-
AnyAutoCustoms,
58-
AutoFragmentIn,
59-
AutoFragmentOut,
60-
AutoVertexIn,
61-
AutoVertexOut,
62-
} from '../function/autoIO.ts';
63-
import type { InstanceToSchema } from '../../data/instanceToSchema.ts';
6464

6565
// ----------
6666
// Public API
@@ -80,6 +80,24 @@ export interface TgpuGuardedComputePipeline<TArgs extends number[] = number[]> e
8080
*/
8181
with(encoder: GPUCommandEncoder): TgpuGuardedComputePipeline<TArgs>;
8282

83+
/**
84+
* Returns a pipeline wrapper with the given performance callback attached.
85+
* Analogous to `TgpuComputePipeline.withPerformanceCallback(callback)`.
86+
*/
87+
withPerformanceCallback(
88+
callback: (start: bigint, end: bigint) => void | Promise<void>,
89+
): TgpuGuardedComputePipeline<TArgs>;
90+
91+
/**
92+
* Returns a pipeline wrapper with the given timestamp writes configuration.
93+
* Analogous to `TgpuComputePipeline.withTimestampWrites(options)`.
94+
*/
95+
withTimestampWrites(options: {
96+
querySet: TgpuQuerySet<'timestamp'> | GPUQuerySet;
97+
beginningOfPassWriteIndex?: number;
98+
endOfPassWriteIndex?: number;
99+
}): TgpuGuardedComputePipeline<TArgs>;
100+
83101
/**
84102
* Dispatches the pipeline.
85103
* Unlike `TgpuComputePipeline.dispatchWorkgroups()`, this method takes in the
@@ -378,7 +396,7 @@ export interface WithBinding extends Withable<WithBinding> {
378396

379397
/**
380398
* Creates a compute pipeline that executes the given callback in an exact number of threads.
381-
* This is different from `withCompute(...).createPipeline()` in that it does a bounds check on the
399+
* This is different from `createComputePipeline()` in that it does a bounds check on the
382400
* thread id, where as regular pipelines do not and work in units of workgroups.
383401
*
384402
* @param callback A function converted to WGSL and executed on the GPU.

packages/typegpu/tests/guardedComputePipeline.test.ts

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { describe, expect } from 'vitest';
1+
import { describe, expect, vi } from 'vitest';
22
import { it } from 'typegpu-testing-utility';
33
import { getName } from '../src/shared/meta.ts';
44
import { bindGroupLayout } from '../src/tgpuBindGroupLayout.ts';
@@ -31,4 +31,34 @@ describe('TgpuGuardedComputePipeline', () => {
3131
expect(getName(pipeline)).toBe('myPipeline');
3232
expect(getName(pipeline.pipeline)).toBe('myPipeline');
3333
});
34+
35+
it('delegates `withPerformanceCallback` to the underlying pipeline', ({ root }) => {
36+
const callback = vi.fn();
37+
const guarded = root.createGuardedComputePipeline(() => {
38+
'use gpu';
39+
});
40+
41+
const spy = vi.spyOn(guarded.pipeline, 'withPerformanceCallback');
42+
guarded.withPerformanceCallback(callback);
43+
44+
expect(spy).toHaveBeenCalledWith(callback);
45+
});
46+
47+
it('delegates `withTimestampWrites` to the underlying pipeline', ({ root }) => {
48+
const querySet = root.createQuerySet('timestamp', 2);
49+
const guarded = root.createGuardedComputePipeline(() => {
50+
'use gpu';
51+
});
52+
53+
const options = {
54+
querySet,
55+
beginningOfPassWriteIndex: 0,
56+
endOfPassWriteIndex: 1,
57+
};
58+
59+
const spy = vi.spyOn(guarded.pipeline, 'withTimestampWrites');
60+
guarded.withTimestampWrites(options);
61+
62+
expect(spy).toHaveBeenCalledWith(options);
63+
});
3464
});

0 commit comments

Comments
 (0)