diff --git a/package-lock.json b/package-lock.json index 06bc8ec3e..1b58625b2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -25,7 +25,7 @@ "rollup": "^2.70.0", "simple-git": "^3.10.0", "three": "^0.183.1", - "three-mesh-bvh": "^0.9.5", + "three-mesh-bvh": "^0.9.8", "typescript": "^5.9.2", "vite": "^6.2.2", "yargs": "^17.5.1" @@ -3657,9 +3657,9 @@ "license": "MIT" }, "node_modules/three-mesh-bvh": { - "version": "0.9.5", - "resolved": "https://registry.npmjs.org/three-mesh-bvh/-/three-mesh-bvh-0.9.5.tgz", - "integrity": "sha512-MYpwzUWDxPAKGhSBFin9E/7K4AAHyIm4IfMZQ/3+Z/jq/swa2dAhXx0yUNDd9mjlhLuzXkMBTGDZioL2GSlIfQ==", + "version": "0.9.8", + "resolved": "https://registry.npmjs.org/three-mesh-bvh/-/three-mesh-bvh-0.9.8.tgz", + "integrity": "sha512-YphYvdXEZSXdz6iNdWJo1RB6qvSCRyiXPEVSvNU6xVWbLDOdSrfEIsJOpgFOnefdmVEvZ6M+sY0cjh9gl7MvdA==", "dev": true, "license": "MIT", "peerDependencies": { @@ -6432,9 +6432,9 @@ "dev": true }, "three-mesh-bvh": { - "version": "0.9.5", - "resolved": "https://registry.npmjs.org/three-mesh-bvh/-/three-mesh-bvh-0.9.5.tgz", - "integrity": "sha512-MYpwzUWDxPAKGhSBFin9E/7K4AAHyIm4IfMZQ/3+Z/jq/swa2dAhXx0yUNDd9mjlhLuzXkMBTGDZioL2GSlIfQ==", + "version": "0.9.8", + "resolved": "https://registry.npmjs.org/three-mesh-bvh/-/three-mesh-bvh-0.9.8.tgz", + "integrity": "sha512-YphYvdXEZSXdz6iNdWJo1RB6qvSCRyiXPEVSvNU6xVWbLDOdSrfEIsJOpgFOnefdmVEvZ6M+sY0cjh9gl7MvdA==", "dev": true, "requires": {} }, diff --git a/package.json b/package.json index a8360ea92..46b0f3d4b 100644 --- a/package.json +++ b/package.json @@ -46,7 +46,7 @@ "rollup": "^2.70.0", "simple-git": "^3.10.0", "three": "^0.183.1", - "three-mesh-bvh": "^0.9.5", + "three-mesh-bvh": "^0.9.8", "typescript": "^5.9.2", "vite": "^6.2.2", "yargs": "^17.5.1" diff --git a/src/webgpu/MegaKernelPathTracer.js b/src/webgpu/MegaKernelPathTracer.js index 1a390d6dc..82b1c6e03 100644 --- a/src/webgpu/MegaKernelPathTracer.js +++ b/src/webgpu/MegaKernelPathTracer.js @@ -8,7 +8,6 @@ function* renderTask() { renderer, camera, kernel, - geometry, bounces, tiles, @@ -22,13 +21,6 @@ function* renderTask() { kernel.outputTarget = outputTarget; kernel.sampleCountTarget = sampleCountTarget; - kernel.geom_index = geometry.index; - kernel.geom_position = geometry.position; - kernel.geom_normals = geometry.normal; - kernel.geom_material_index = geometry.materialIndex; - kernel.bvh = geometry.bvh; - kernel.materials = geometry.materials; - kernel.bounces = bounces; kernel.inverseProjectionMatrix.copy( camera.projectionMatrixInverse ); kernel.cameraToModelMatrix.copy( camera.matrixWorld ); @@ -77,17 +69,6 @@ export class MegaKernelPathTracer { this.bounces = 7; this.tiles = new Vector2( 2, 2 ); - // geometry fields - this.geometry = { - bvh: null, - index: null, - position: null, - normal: null, - - materialIndex: null, - materials: null, - }; - // targets this.outputTarget = new StorageTexture( 1, 1, ); this.outputTarget.format = RGBAFormat; @@ -118,23 +99,11 @@ export class MegaKernelPathTracer { } - setGeometryData( geometry ) { - - for ( const propName in geometry ) { - - const prop = this.geometry[ propName ]; - if ( prop === undefined ) { - - console.error( `Invalid property name in geometry data: ${propName}` ); - continue; + setBVHData( bvhData ) { - } - - // TODO: cannot dispose at the moment - // prop.dispose(); - this.geometry[ propName ] = geometry[ propName ]; - - } + this.kernel.bvhData = bvhData; + this.kernel.needsUpdate = true; + this.reset(); } @@ -223,7 +192,6 @@ export class MegaKernelPathTracer { this.samples = 0; this._task = null; - const { width, height } = sampleCountTarget; const dispatchSize = sampleCountClearKernel.getDispatchSize( width, height ); @@ -240,7 +208,7 @@ export class MegaKernelPathTracer { update() { - if ( ! this.camera ) { + if ( ! this.camera || ! this.kernel ) { return; diff --git a/src/webgpu/WaveFrontPathTracer.js b/src/webgpu/WaveFrontPathTracer.js index ee9a8a2f0..5aad3a6a6 100644 --- a/src/webgpu/WaveFrontPathTracer.js +++ b/src/webgpu/WaveFrontPathTracer.js @@ -225,23 +225,15 @@ export class WaveFrontPathTracer { } - setGeometryData( geometry ) { + setBVHData( bvhData ) { - for ( const propName in geometry ) { + this.rayIntersectionKernel.bvhData = bvhData; + this.rayIntersectionKernel.needsUpdate = true; - const prop = this.geometry[ propName ]; - if ( prop === undefined ) { + this.hitProcessKernel.bvhData = bvhData; + this.hitProcessKernel.needsUpdate = true; - console.error( `Invalid property name in geometry data: ${propName}` ); - continue; - - } - - // TODO: cannot dispose at the moment - // prop.dispose(); - this.geometry[ propName ] = geometry[ propName ]; - - } + this.reset(); } diff --git a/src/webgpu/WebGPUPathTracer.js b/src/webgpu/WebGPUPathTracer.js index a38ddeae4..f592a513b 100644 --- a/src/webgpu/WebGPUPathTracer.js +++ b/src/webgpu/WebGPUPathTracer.js @@ -1,9 +1,11 @@ -import { Color, StorageBufferAttribute, PerspectiveCamera, Scene, Vector2, Clock } from 'three/webgpu'; -import { PathTracingSceneGenerator } from '../core/PathTracingSceneGenerator.js'; +import { Vector2, Scene, PerspectiveCamera } from 'three/webgpu'; +import { MeshBVH, SAH } from 'three-mesh-bvh'; import { FullScreenQuad } from 'three/examples/jsm/postprocessing/Pass.js'; import { RenderToScreenNodeMaterial } from './materials/RenderToScreenMaterial.js'; import { MegaKernelPathTracer } from './MegaKernelPathTracer.js'; import { WaveFrontPathTracer } from './WaveFrontPathTracer.js'; +import { ObjectBVH } from './lib/ObjectBVH.js'; +import { PathtracerBVHComputeData } from './nodes/PathtracerBVHComputeData.js'; const _resolution = new Vector2(); export class WebGPUPathTracer { @@ -24,7 +26,8 @@ export class WebGPUPathTracer { this._pathTracer.dispose(); this._pathTracer = value ? new MegaKernelPathTracer( this._renderer ) : new WaveFrontPathTracer( this._renderer ); - this._generator = new PathTracingSceneGenerator(); + this._pathTracer.setBVHData( this._bvhData ); + this.setCamera( this.camera ); } @@ -32,11 +35,7 @@ export class WebGPUPathTracer { // members this._renderer = renderer; - this._generator = new PathTracingSceneGenerator(); - // this._pathTracer = new MegaKernelPathTracer( renderer ); - this._pathTracer = new WaveFrontPathTracer( renderer ); - this._queueReset = false; - this._clock = new Clock(); + this._pathTracer = new MegaKernelPathTracer( renderer ); // options this.renderScale = 1; @@ -54,11 +53,25 @@ export class WebGPUPathTracer { scene.updateMatrixWorld( true ); camera.updateMatrixWorld(); - const generator = this._generator; - generator.setObjects( scene ); + // Build BVH for each mesh geometry + scene.traverse( child => { - const result = generator.generate(); - return this._updateFromResults( scene, camera, result ); + if ( child.isMesh && ! child.geometry.boundsTree ) { + + child.geometry.boundsTree = new MeshBVH( child.geometry, { strategy: SAH, maxLeafSize: 5 } ); + + } + + } ); + + // Build TLAS and compute functions + const objectBVH = new ObjectBVH( scene, { strategy: SAH } ); + const bvhData = new PathtracerBVHComputeData( objectBVH ); + bvhData.update(); + + this._bvhData = bvhData; + this._pathTracer.setBVHData( bvhData ); + this.setCamera( camera ); } @@ -85,80 +98,6 @@ export class WebGPUPathTracer { } - _updateFromResults( scene, camera, results ) { - - const { - materials, - geometry, - bvh, - bvhChanged, - needsMaterialIndexUpdate, - } = results; - - const pathTracer = this._pathTracer; - - const newGeometryData = {}; - - if ( bvhChanged ) { - - // dereference a new index attribute if we're using indirect storage - const dereferencedIndexAttr = geometry.index.clone(); - const indirectBuffer = bvh._indirectBuffer; - if ( indirectBuffer ) { - - dereferenceIndex( geometry, indirectBuffer, dereferencedIndexAttr ); - - } - - const newIndex = new StorageBufferAttribute( dereferencedIndexAttr.array, 3 ); - newIndex.name = 'Geometry Index'; - newGeometryData.index = newIndex; - - const newPosition = new StorageBufferAttribute( geometry.attributes.position.array, 3 ); - newPosition.name = 'Geometry Positions'; - newGeometryData.position = newPosition; - - const newNormals = new StorageBufferAttribute( geometry.attributes.normal.array, 3 ); - newNormals.name = 'Geometry Normals'; - newGeometryData.normal = newNormals; - - const newBvhRoots = new StorageBufferAttribute( new Float32Array( bvh._roots[ 0 ] ), 8 ); - newBvhRoots.name = 'BVH Roots'; - newGeometryData.bvh = newBvhRoots; - - } - - if ( needsMaterialIndexUpdate ) { - - const newMaterialIndex = new StorageBufferAttribute( geometry.attributes.materialIndex.array, 1 ); - newMaterialIndex.name = 'Material Index'; - newGeometryData.materialIndex = newMaterialIndex; - - } - - const newMaterialsData = new Float32Array( materials.length * 3 ); - const defaultColor = new Color(); - for ( let i = 0; i < materials.length; i ++ ) { - - const material = materials[ i ]; - const color = material.color ?? defaultColor; - // Make sure those are in linear-sRGB space - newMaterialsData[ 3 * i + 0 ] = color.r; - newMaterialsData[ 3 * i + 1 ] = color.g; - newMaterialsData[ 3 * i + 2 ] = color.b; - - } - - const newMaterialsBuffer = new StorageBufferAttribute( newMaterialsData, 3 ); - newMaterialsBuffer.name = 'Material Data'; - newGeometryData.materials = newMaterialsBuffer; - - pathTracer.setGeometryData( newGeometryData ); - - this.setCamera( camera ); - - } - renderSample() { if ( ! this._renderer._initialized ) { @@ -216,22 +155,3 @@ export class WebGPUPathTracer { } } - -// TODO: Expose in three-mesh-bvh? -function dereferenceIndex( geometry, indirectBuffer, target ) { - - const unpacked = target.array; - const indexArray = geometry.index ? geometry.index.array : null; - for ( let i = 0, l = indirectBuffer.length; i < l; i ++ ) { - - const i3 = 3 * i; - const v3 = 3 * indirectBuffer[ i ]; - for ( let c = 0; c < 3; c ++ ) { - - unpacked[ i3 + c ] = indexArray ? indexArray[ v3 + c ] : v3 + c; - - } - - } - -} diff --git a/src/webgpu/compute/ComputeKernel.js b/src/webgpu/compute/ComputeKernel.js index 520f475a2..81ba08eec 100644 --- a/src/webgpu/compute/ComputeKernel.js +++ b/src/webgpu/compute/ComputeKernel.js @@ -12,6 +12,13 @@ export class ComputeKernel { } + set needsUpdate( v ) { + + // TODO: hack to force the kernel to rebuild since "needsUpdate" is not respected + this.setWorkgroupSize( ...this.workgroupSize ); + + } + constructor( fn, options = {} ) { const { @@ -60,7 +67,6 @@ export class ComputeKernel { setWorkgroupSize( x = 64, y = 1, z = 1 ) { - // this.workgroupSize = [ x, y, z ]; this.kernel = this._fn.computeKernel( [ x, y, z ] ); return this; diff --git a/src/webgpu/compute/PathTracerMegaKernel.js b/src/webgpu/compute/PathTracerMegaKernel.js index 90eca4eb9..cea76dd40 100644 --- a/src/webgpu/compute/PathTracerMegaKernel.js +++ b/src/webgpu/compute/PathTracerMegaKernel.js @@ -1,20 +1,24 @@ -import { IndirectStorageBufferAttribute, Matrix4, Vector2, StorageTexture } from 'three/webgpu'; +import { Matrix4, Vector2, StorageTexture } from 'three/webgpu'; +import { ndcToCameraRay } from '../lib/wgsl/common.wgsl.js'; import { ComputeKernel } from './ComputeKernel.js'; -import { uniform, storage, globalId, textureStore } from 'three/tsl'; -import megakernelShader from '../nodes/megakernel.wgsl.js'; +import { uniform, globalId, textureStore, wgslFn } from 'three/tsl'; +import { pcgRand3, pcgInit } from '../nodes/random.wgsl.js'; +import { lambertBsdfFunc } from '../nodes/sampling.wgsl.js'; +import { proxy } from '../lib/nodes/NodeProxy.js'; export class PathTracerMegaKernel extends ComputeKernel { constructor() { - const megakernelShaderParams = { + const parameters = { + bvhData: { value: null }, + prevOutputTarget: textureStore( new StorageTexture( 1, 1 ) ).toReadOnly(), outputTarget: textureStore( new StorageTexture( 1, 1 ) ).toWriteOnly(), sampleCountTarget: textureStore( new StorageTexture( 1, 1 ) ).toReadWrite(), offset: uniform( new Vector2() ), tileSize: uniform( new Vector2() ), - smoothNormals: uniform( 1 ), seed: uniform( 0 ), bounces: uniform( 5 ), @@ -22,22 +26,109 @@ export class PathTracerMegaKernel extends ComputeKernel { inverseProjectionMatrix: uniform( new Matrix4() ), cameraToModelMatrix: uniform( new Matrix4() ), - // bvh and geometry definition - geom_index: storage( new IndirectStorageBufferAttribute( 1, 3 ), 'vec3u' ).toReadOnly(), - geom_position: storage( new IndirectStorageBufferAttribute( 1, 3 ), 'vec3f' ).toReadOnly(), - geom_normals: storage( new IndirectStorageBufferAttribute( 1, 3 ), 'vec3f' ).toReadOnly(), - geom_material_index: storage( new IndirectStorageBufferAttribute( 1, 1 ), 'u32' ).toReadOnly(), - bvh: storage( new IndirectStorageBufferAttribute(), 'BVHNode' ).toReadOnly(), // TODO: fill this in - - materials: storage( new IndirectStorageBufferAttribute(), 'Material' ).toReadOnly(), // TODO: fill this in - // compute variables globalId: globalId, }; - super( megakernelShader( megakernelShaderParams ) ); + const shader = wgslFn( /* wgsl */` + + fn compute( + + // indices and target + globalId: vec3u, + prevOutputTarget: texture_storage_2d, + outputTarget: texture_storage_2d, + sampleCountTarget: texture_storage_2d, + + // tiles + offset: vec2u, + tileSize: vec2u, + + // settings + inverseProjectionMatrix: mat4x4f, + cameraToModelMatrix: mat4x4f, + seed: u32, + bounces: u32, + + ) -> void { + + // make sure we don't bleed over the edge of our tile + if ( globalId.x >= tileSize.x || globalId.y >= tileSize.y ) { + + return; + + } + + // to screen coordinates + let indexUV = offset + globalId.xy; + let targetDimensions = textureDimensions( outputTarget ); + if ( indexUV.x >= targetDimensions.x || indexUV.y >= targetDimensions.y ) { + + return; + + } + + let uv = vec2f( indexUV ) / vec2f( targetDimensions ); + let ndc = uv * 2.0 - vec2f( 1.0 ); + + pcgInitialize( indexUV, seed ); + + // scene ray + var jitter = 2.0 * ( pcgRand2() - vec2( 0.5 ) ) / vec2f( targetDimensions.xy ); + var ray = ndcToCameraRay( ndc + jitter, cameraToModelMatrix * inverseProjectionMatrix ); + + var resultColor = vec3f( 0.0 ); + var throughputColor = vec3f( 1.0 ); + + for ( var bounce = 0u; bounce < bounces; bounce ++ ) { + + let hitResult = bvh_RaycastFirstHit( ray ); + if ( hitResult.didHit ) { + + let vertexData = bvh_sampleTrianglePoint( hitResult.barycoord, hitResult.indices.xyz ); + let hitPosition = ray.origin + ray.direction * hitResult.dist; + let scatterRec = bsdfEval( normalize( vertexData.normal.xyz ), - ray.direction ); + + let transform = bvh_transforms.value[ hitResult.objectIndex ]; + let material = bvh_materials.value[ transform.materialIndex ]; + + // white diffuse surface + throughputColor *= material.albedo * scatterRec.value / scatterRec.pdf; + + ray.origin = hitPosition; + ray.direction = scatterRec.direction; + + } else { + + let background = vec3f( 0.5 ); + resultColor += background * throughputColor; + break; + + } + + } + + let sampleCount = textureLoad( sampleCountTarget, indexUV ).r + 1; + var color = textureLoad( prevOutputTarget, indexUV ).xyz; + color += ( resultColor - color.xyz ) / f32( sampleCount ); + + textureStore( sampleCountTarget, indexUV, vec4( sampleCount ) ); + textureStore( outputTarget, indexUV, vec4( color, 1.0 ) ); + + } + + `, [ + proxy( 'bvhData.value.storage.materials', parameters ), + proxy( 'bvhData.value.structs.material', parameters ), + proxy( 'bvhData.value.structs.transform', parameters ), + proxy( 'bvhData.value.fns.raycastFirstHit', parameters ), + proxy( 'bvhData.value.fns.sampleTrianglePoint', parameters ), + ndcToCameraRay, pcgRand3, pcgInit, lambertBsdfFunc, + ] ); + + super( shader( parameters ) ); - this.defineUniformAccessors( megakernelShaderParams ); + this.defineUniformAccessors( parameters ); } diff --git a/src/webgpu/compute/wavefront/ProcessHitsKernel.js b/src/webgpu/compute/wavefront/ProcessHitsKernel.js index ebe21b62c..688c46fa8 100644 --- a/src/webgpu/compute/wavefront/ProcessHitsKernel.js +++ b/src/webgpu/compute/wavefront/ProcessHitsKernel.js @@ -1,17 +1,18 @@ import { IndirectStorageBufferAttribute, StorageTexture } from 'three/webgpu'; import { ComputeKernel } from '../ComputeKernel.js'; import { uniform, storage, wgslFn, textureStore, globalId } from 'three/tsl'; -import { constants, getVertexAttribute } from 'three-mesh-bvh/webgpu'; import { pcgRand3, pcgInit } from '../../nodes/random.wgsl.js'; -import { materialStruct } from '../../nodes/structs.wgsl.js'; import { lambertBsdfFunc } from '../../nodes/sampling.wgsl.js'; import { queuedRayStruct, queuedHitStruct, QUEUED_RAY_SIZE, QUEUED_HIT_SIZE } from './structs.js'; +import { proxy } from '../../lib/nodes/NodeProxy.js'; export class ProcessHitsKernel extends ComputeKernel { constructor() { const parameters = { + bvhData: { value: null }, + prevOutputTarget: textureStore( new StorageTexture( 1, 1 ) ).toReadOnly(), outputTarget: textureStore( new StorageTexture( 1, 1 ) ).toWriteOnly(), sampleCountTarget: textureStore( new StorageTexture( 1, 1 ) ).toReadWrite(), @@ -27,11 +28,6 @@ export class ProcessHitsKernel extends ComputeKernel { hitQueue: storage( new IndirectStorageBufferAttribute( 1, QUEUED_HIT_SIZE ), 'QueuedHit' ), hitQueueSize: storage( new IndirectStorageBufferAttribute( 2, 1 ), 'u32' ), - // bvh and geometry definition - geom_position: storage( new IndirectStorageBufferAttribute( 1, 3 ), 'vec3f' ).toReadOnly(), - geom_normals: storage( new IndirectStorageBufferAttribute( 1, 3 ), 'vec3f' ).toReadOnly(), - materials: storage( new IndirectStorageBufferAttribute(), 'Material' ).toReadOnly(), // TODO: fill in initial values - globalId: globalId, }; @@ -55,11 +51,6 @@ export class ProcessHitsKernel extends ComputeKernel { hitQueue: ptr, read_write>, hitQueueSize: ptr, read_write>, - // scene - geom_position: ptr, read>, - geom_normals: ptr, read>, - materials: ptr, read>, - globalId: vec3u ) -> void { @@ -80,10 +71,13 @@ export class ProcessHitsKernel extends ComputeKernel { pcgInitialize( indexUV, seed ); - let material = materials[ input.materialIndex ]; - let hitPosition = getVertexAttribute( input.barycoord, input.indices.xyz, geom_position ); - let hitNormal = getVertexAttribute( input.barycoord, input.indices.xyz, geom_normals ); - let scatterRec = bsdfEval( hitNormal, input.view ); + let object = bvh_transforms.value[ input.objectIndex ]; + let material = bvh_materials.value[ object.materialIndex ]; + var vertexData = bvh_sampleTrianglePoint( input.barycoord, input.indices.xyz ); + vertexData.normal = normalize( transpose( object.inverseMatrixWorld ) * vertexData.normal ); + vertexData.position = object.matrixWorld * vertexData.position; + + let scatterRec = bsdfEval( vertexData.normal.xyz, input.view ); if ( input.currentBounce >= bounces ) { @@ -99,7 +93,7 @@ export class ProcessHitsKernel extends ComputeKernel { let rayQueueCapacity = arrayLength( rayQueue ); let index = atomicAdd( &rayQueueSize[ 1 ], 1 ) % rayQueueCapacity; - rayQueue[ index ].ray.origin = hitPosition; + rayQueue[ index ].ray.origin = vertexData.position.xyz; rayQueue[ index ].ray.direction = scatterRec.direction; rayQueue[ index ].pixel = indexUV; rayQueue[ index ].throughputColor = input.throughputColor * material.albedo * scatterRec.value / scatterRec.pdf; @@ -108,7 +102,15 @@ export class ProcessHitsKernel extends ComputeKernel { } } - `, [ queuedRayStruct, lambertBsdfFunc, constants, getVertexAttribute, pcgRand3, pcgInit, queuedHitStruct, materialStruct ] ); + `, [ + proxy( 'bvhData.value.structs.material', parameters ), + proxy( 'bvhData.value.structs.transform', parameters ), + proxy( 'bvhData.value.storage.materials', parameters ), + proxy( 'bvhData.value.storage.transforms', parameters ), + proxy( 'bvhData.value.fns.sampleTrianglePoint', parameters ), + queuedRayStruct, lambertBsdfFunc, + pcgRand3, pcgInit, queuedHitStruct, + ] ); super( fn( parameters ) ); diff --git a/src/webgpu/compute/wavefront/RayGenerationKernel.js b/src/webgpu/compute/wavefront/RayGenerationKernel.js index 00ee6a873..696103f0d 100644 --- a/src/webgpu/compute/wavefront/RayGenerationKernel.js +++ b/src/webgpu/compute/wavefront/RayGenerationKernel.js @@ -2,7 +2,7 @@ import { Vector2, Matrix4 } from 'three'; import { IndirectStorageBufferAttribute, StorageTexture } from 'three/webgpu'; import { wgslFn, uniform, storage, globalId, textureStore } from 'three/tsl'; import { ComputeKernel } from '../ComputeKernel.js'; -import { ndcToCameraRay } from 'three-mesh-bvh/webgpu'; +import { ndcToCameraRay } from '../../lib/wgsl/common.wgsl.js'; import { pcgInit, pcgRand2 } from '../../nodes/random.wgsl.js'; import { QUEUED_RAY_SIZE, queuedRayStruct } from './structs.js'; diff --git a/src/webgpu/compute/wavefront/RayIntersectionKernel.js b/src/webgpu/compute/wavefront/RayIntersectionKernel.js index 086caa575..1c0e28981 100644 --- a/src/webgpu/compute/wavefront/RayIntersectionKernel.js +++ b/src/webgpu/compute/wavefront/RayIntersectionKernel.js @@ -1,15 +1,17 @@ import { IndirectStorageBufferAttribute, StorageTexture } from 'three/webgpu'; import { ComputeKernel } from '../ComputeKernel.js'; import { storage, wgslFn, textureStore, globalId } from 'three/tsl'; -import { bvhIntersectFirstHit, constants } from 'three-mesh-bvh/webgpu'; import { pcgRand3, pcgInit } from '../../nodes/random.wgsl.js'; import { queuedRayStruct, queuedHitStruct, QUEUED_RAY_SIZE, QUEUED_HIT_SIZE } from './structs.js'; +import { proxy } from '../../lib/nodes/NodeProxy.js'; export class RayIntersectionKernel extends ComputeKernel { constructor() { const parameters = { + bvhData: { value: null }, + prevOutputTarget: textureStore( new StorageTexture( 1, 1 ) ).toReadOnly(), outputTarget: textureStore( new StorageTexture( 1, 1 ) ).toWriteOnly(), sampleCountTarget: textureStore( new StorageTexture( 1, 1 ) ).toReadWrite(), @@ -21,12 +23,6 @@ export class RayIntersectionKernel extends ComputeKernel { hitQueue: storage( new IndirectStorageBufferAttribute( 1, QUEUED_HIT_SIZE ), 'QueuedHit' ), hitQueueSize: storage( new IndirectStorageBufferAttribute( 2, 1 ), 'u32' ).toAtomic(), - // bvh and geometry definition - geom_index: storage( new IndirectStorageBufferAttribute( 1, 3 ), 'vec3u' ).toReadOnly(), - geom_position: storage( new IndirectStorageBufferAttribute( 1, 3 ), 'vec3f' ).toReadOnly(), - geom_material_index: storage( new IndirectStorageBufferAttribute( 1, 1 ), 'u32' ).toReadOnly(), - bvh: storage( new IndirectStorageBufferAttribute(), 'BVHNode' ).toReadOnly(), // TODO: fill in sizes - globalId: globalId, }; @@ -46,12 +42,6 @@ export class RayIntersectionKernel extends ComputeKernel { hitQueue: ptr, read_write>, hitQueueSize: ptr>, read_write>, - // scene - geom_position: ptr, read>, - geom_index: ptr, read>, - geom_material_index: ptr, read>, - bvh: ptr, read>, - globalId: vec3u ) -> void { @@ -73,18 +63,17 @@ export class RayIntersectionKernel extends ComputeKernel { pcgInitialize( indexUV, seed ); // run intersection - let hitResult = bvhIntersectFirstHit( geom_index, geom_position, bvh, input.ray ); + let hitResult = bvh_RaycastFirstHit( input.ray ); if ( hitResult.didHit ) { // TODO: we process all of these materials immediately to push to the ray queue - let materialIndex = geom_material_index[ hitResult.indices.x ]; let index = atomicAdd( &hitQueueSize[ 1 ], 1 ); hitQueue[ index ].view = - input.ray.direction; hitQueue[ index ].indices = hitResult.indices.xyz; hitQueue[ index ].barycoord = hitResult.barycoord; hitQueue[ index ].pixel_x = input.pixel.x; hitQueue[ index ].pixel_y = input.pixel.y; - hitQueue[ index ].materialIndex = materialIndex; + hitQueue[ index ].objectIndex = hitResult.objectIndex; hitQueue[ index ].throughputColor = input.throughputColor; hitQueue[ index ].currentBounce = input.currentBounce;; @@ -103,7 +92,10 @@ export class RayIntersectionKernel extends ComputeKernel { } } - `, [ queuedRayStruct, bvhIntersectFirstHit, constants, pcgRand3, pcgInit, queuedHitStruct ] ); + `, [ + proxy( 'bvhData.value.fns.raycastFirstHit', parameters ), + proxy( 'bvhData.value.structs.material', parameters ), + queuedRayStruct, pcgRand3, pcgInit, queuedHitStruct ] ); super( fn( parameters ) ); diff --git a/src/webgpu/compute/wavefront/structs.js b/src/webgpu/compute/wavefront/structs.js index f84308cb2..06e69ef2f 100644 --- a/src/webgpu/compute/wavefront/structs.js +++ b/src/webgpu/compute/wavefront/structs.js @@ -1,5 +1,5 @@ import { wgsl } from 'three/tsl'; -import { rayStruct } from 'three-mesh-bvh/webgpu'; +import { rayStruct } from '../../lib/wgsl/structs.wgsl.js'; export const QUEUED_RAY_SIZE = 16; @@ -23,6 +23,6 @@ export const queuedHitStruct = wgsl( /* wgsl */` view: vec3f, currentBounce: u32, throughputColor: vec3f, - materialIndex: u32, + objectIndex: u32, }; ` ); diff --git a/src/webgpu/lib/BVHComputeData.js b/src/webgpu/lib/BVHComputeData.js new file mode 100644 index 000000000..71434f2c4 --- /dev/null +++ b/src/webgpu/lib/BVHComputeData.js @@ -0,0 +1,769 @@ +import { Matrix4, Vector4 } from 'three'; +import { StorageBufferAttribute, StructTypeNode } from 'three/webgpu'; +import { storage } from 'three/tsl'; +import { rayIntersectsBounds, constants } from './wgsl/common.wgsl.js'; +import { rayStruct, bvhNodeStruct } from './wgsl/structs.wgsl.js'; +import { wgslTagCode, wgslTagFn } from './nodes/WGSLTagFnNode.js'; + +// TODO: add ability to easily update a single matrix / scene rearrangement (partial update) +// TODO: add material support w/ function to easily update material +// - add a callback for writing a property for a geometry to a range +// TODO: add skinned mesh bvh support + +// temporary shim so StructTypeNodes can be passed to storage functions until +// this is fixed in three.js +Object.defineProperty( StructTypeNode.prototype, 'layout', { + + get() { + + return this; + + } + +} ); +StructTypeNode.prototype.isStruct = true; + +// + +// structs +const transformStruct = new StructTypeNode( { + matrixWorld: 'mat4x4f', + inverseMatrixWorld: 'mat4x4f', + nodeOffset: 'uint', + _alignment0: 'uint', + _alignment1: 'uint', + _alignment2: 'uint', +}, 'TransformStruct' ); + +const intersectionResultStruct = new StructTypeNode( { + indices: 'vec4u', + normal: 'vec3f', + didHit: 'bool', + barycoord: 'vec3f', + objectIndex: 'uint', + side: 'float', + dist: 'float', +}, 'IntersectionResult' ); + +// + +// node constants +const BYTES_PER_NODE = 6 * 4 + 4 + 4; +const UINT32_PER_NODE = BYTES_PER_NODE / 4; +const IS_LEAFNODE_FLAG = 0xFFFF; + +// scratch +const _def = /* @__PURE__ */ new Vector4(); +const _vec = /* @__PURE__ */ new Vector4(); +const _matrix = /* @__PURE__ */ new Matrix4(); +const _inverseMatrix = /* @__PURE__ */ new Matrix4(); + +// functions +function dereferenceIndex( indexAttr, indirectBuffer ) { + + const indexArray = indexAttr ? indexAttr.array : null; + const result = new Uint32Array( indirectBuffer.length * 3 ); + for ( let i = 0, l = indirectBuffer.length; i < l; i ++ ) { + + const i3 = 3 * i; + const v3 = 3 * indirectBuffer[ i ]; + for ( let c = 0; c < 3; c ++ ) { + + result[ i3 + c ] = indexArray ? indexArray[ v3 + c ] : v3 + c; + + } + + } + + return result; + +} + +function getTotalBVHByteLength( bvh ) { + + return bvh._roots.reduce( ( v, root ) => v + root.byteLength, 0 ); + +} + +const intersectsTriangle = wgslTagFn/* wgsl */ ` + // fn + fn intersectsTriangle( ray: ${ rayStruct }, a: vec3f, b: vec3f, c: vec3f ) -> ${ intersectionResultStruct } { + + var TRI_INTERSECT_EPSILON = ${ constants.TRI_INTERSECT_EPSILON }; + var result: ${ intersectionResultStruct }; + result.didHit = false; + + let edge1 = b - a; + let edge2 = c - a; + let n = cross( edge1, edge2 ); + + let det = - dot( ray.direction, n ); + + if ( abs( det ) < TRI_INTERSECT_EPSILON ) { + + return result; + + } + + let invdet = 1.0 / det; + + let AO = ray.origin - a; + let DAO = cross( AO, ray.direction ); + + let u = dot( edge2, DAO ) * invdet; + let v = -dot( edge1, DAO ) * invdet; + let t = dot( AO, n ) * invdet; + + let w = 1.0 - u - v; + + if ( u < - TRI_INTERSECT_EPSILON || v < - TRI_INTERSECT_EPSILON || w < - TRI_INTERSECT_EPSILON || t < TRI_INTERSECT_EPSILON ) { + + return result; + + } + + result.didHit = true; + result.barycoord = vec3f( w, u, v ); + result.dist = t; + result.side = sign( det ); + result.normal = result.side * normalize( n ); + + return result; + + } +`; + +export class BVHComputeData { + + constructor( bvh, options = {} ) { + + const { + prefix = 'bvh_', + attributes = { position: 'vec4f' }, + } = options; + + this.prefix = prefix; + this.attributes = attributes; + this.bvh = bvh; + + this.storage = { + index: null, + attributes: null, + nodes: null, + transforms: null, + }; + + this.structs = { + transform: transformStruct, + attributes: null, + }; + + this.fns = { + raycastFirstHit: null, + }; + + } + + getShapecastFn( options ) { + + const { + name, + shapeStruct, + resultStruct, + + boundsOrderFn, + intersectsBoundsFn, + intersectRangeFn, + transformShapeFn, + transformResultFn, + } = options; + + const { storage } = this; + const { BVH_STACK_DEPTH, INFINITY } = constants; + const getFnBody = leafSnippet => { + + // returns a function with a snippet inserted for the leaf intersection test + return wgslTagCode/* wgsl */` + var bestHit: ${ resultStruct }; + bestHit.didHit = false; + bestHit.dist = bestDist; + + var pointer: i32 = 0; + var stack: array; + stack[ 0 ] = rootNodeIndex; + + loop { + + if ( pointer < 0 || pointer >= i32( ${ BVH_STACK_DEPTH } ) ) { + + break; + + } + + let nodeIndex = stack[ pointer ]; + let node = ${ storage.nodes }[ nodeIndex ]; + pointer = pointer - 1; + + var boundsHitDist: f32 = 0.0; + if ( ! ${ intersectsBoundsFn }( shape, node.bounds, &boundsHitDist ) || boundsHitDist > bestHit.dist ) { + + continue; + + } + + let infoX = node.splitAxisOrTriangleCount; + let infoY = node.rightChildOrTriangleOffset; + let isLeaf = ( infoX & 0xffff0000u ) != 0u; + + if ( isLeaf ) { + + let count = infoX & 0x0000ffffu; + let offset = infoY; + ${ leafSnippet } + + } else { + + let leftIndex = nodeIndex + 1u; + let splitAxis = infoX & 0x0000ffffu; + let rightIndex = nodeIndex + infoY; + + let leftToRight = ${ boundsOrderFn }( shape, splitAxis, node ); + let c1 = select( rightIndex, leftIndex, leftToRight ); + let c2 = select( leftIndex, rightIndex, leftToRight ); + + pointer = pointer + 1; + stack[ pointer ] = c2; + + pointer = pointer + 1; + stack[ pointer ] = c1; + + } + + } + + return bestHit; + `; + + }; + + const blasFn = wgslTagFn/* wgsl */` + // fn + fn ${ name }_blas( shape: ${ shapeStruct }, rootNodeIndex: u32, bestDist: f32 ) -> ${ resultStruct } { + + ${ getFnBody( wgslTagCode/* wgsl */` + + let result = ${ intersectRangeFn }( shape, offset, count, bestDist ); + if ( result.didHit && result.dist < bestHit.dist ) { + + bestHit = result; + + } + + ` ) } + + } + `; + + const tlasFn = wgslTagFn/* wgsl */` + // fn + fn ${ name }( shape: ${ shapeStruct } ) -> ${ resultStruct } { + + let bestDist = ${ INFINITY }; + let rootNodeIndex = 0u; + + ${ getFnBody( wgslTagCode/* wgsl */` + + for ( var t = offset; t < offset + count; t = t + 1u ) { + + let transform = ${ storage.transforms }[ t ]; + + // Transform shape into object local space + let localShape = ${ transformShapeFn }( shape, transform.inverseMatrixWorld ); + let blasHit = ${ blasFn( { shape: 'localShape', rootNodeIndex: 'transform.nodeOffset', bestDist: 'bestHit.dist' } ) }; + if ( blasHit.didHit && blasHit.dist < bestHit.dist ) { + + bestHit = blasHit; + bestHit.objectIndex = t; + + ${ transformResultFn }( &bestHit, transform.matrixWorld, transform.inverseMatrixWorld ); + + } + + } + + ` ) } + + } + `; + + return tlasFn; + + } + + update() { + + const self = this; + const { attributes, structs, prefix, bvh } = this; + + // collect the BVHs + const bvhInfo = []; + const transformInfo = []; + + // accumulate the sizes of the bvh nodes buffer, number of objects, and geometry buffers + let bvhNodesBufferLength = getTotalBVHByteLength( bvh ); + let indexBufferLength = 0; + let attributesBufferLength = 0; + bvh.primitiveBuffer.forEach( compositeId => { + + const object = bvh.getObjectFromId( compositeId ); + const instanceId = bvh.getInstanceFromId( compositeId ); + const range = { start: 0, count: 0, vertexStart: 0, vertexCount: 0 }; + const primBvh = this.getBVH( object, instanceId, range ); + + // if we haven't added this bvh, yet + if ( ! bvhInfo.find( info => info.bvh === primBvh ) ) { + + // save the geometry info to write later and increment the buffer sizes + const info = { + index: bvhInfo.length, + bvh: primBvh, + range: range, + + bvhBufferOffsets: null, + indexBufferOffset: null, + + }; + + // increase the buffer sizes for bvh and geometry + bvhNodesBufferLength += getTotalBVHByteLength( primBvh ); + indexBufferLength += info.range.count; + attributesBufferLength += info.range.vertexCount; + bvhInfo.push( info ); + + } + + // save the index of the bvh associated with this transform + const data = bvhInfo.find( info => primBvh === info.bvh ); + primBvh._roots.forEach( ( root, i ) => { + + transformInfo.push( { + data, + root: i, + object, + instanceId, + compositeId, + } ); + + } ); + + } ); + + // + + // construct the attribute struct + const attributeStruct = new StructTypeNode( attributes, `${ prefix }GeometryStruct` ); + + // write the geometry buffer attributes & bvh data + let attributesOffset = 0; + let indexOffset = 0; + let nodeWriteOffset = 0; + const indexBuffer = new Uint32Array( indexBufferLength ); + const attributesBuffer = new ArrayBuffer( attributesBufferLength * attributeStruct.getLength() * 4 ); + const bvhNodesBuffer = new ArrayBuffer( bvhNodesBufferLength ); + + // append TLAS data + appendBVHData( bvh, 0, transformInfo, 0, bvhNodesBuffer, true ); + nodeWriteOffset += getTotalBVHByteLength( bvh ) / BYTES_PER_NODE; + bvhInfo.forEach( info => { + + // append bvh data + const bvhNodeOffsets = appendBVHData( info.bvh, indexOffset / 3, transformInfo, nodeWriteOffset, bvhNodesBuffer, false ); + info.bvhNodeOffsets = bvhNodeOffsets; + + // append geometry data + appendIndexData( info.bvh, info.range, attributesOffset, indexOffset, indexBuffer ); + appendGeometryData( info.bvh, info.range, attributesOffset, attributesBuffer ); + info.indexBufferOffset = indexOffset; + + // step the write offsets forward + indexOffset += info.range.count; + attributesOffset += info.range.vertexCount; + nodeWriteOffset += getTotalBVHByteLength( info.bvh ) / BYTES_PER_NODE; + + } ); + + // + + // write the transforms + const transformArrayBuffer = new ArrayBuffer( structs.transform.getLength() * transformInfo.length * 4 ); + transformInfo.forEach( ( info, i ) => { + + _inverseMatrix.copy( bvh.matrixWorld ).invert(); + this.writeTransformData( info, _inverseMatrix, i, transformArrayBuffer ); + + } ); + + // + + // set up the storage buffers + const bvhNodesStorage = storage( new StorageBufferAttribute( new Uint32Array( bvhNodesBuffer ), 8 ), bvhNodeStruct ).toReadOnly().setName( `${ prefix }nodes` ); + const transformsStorage = storage( new StorageBufferAttribute( new Uint32Array( transformArrayBuffer ), structs.transform.getLength() ), structs.transform ).toReadOnly().setName( `${ prefix }transforms` ); + const indexStorage = storage( new StorageBufferAttribute( indexBuffer, 1 ), 'uint' ).toReadOnly().setName( `${ prefix }index` ); + const attributesStorage = storage( new StorageBufferAttribute( new Uint32Array( attributesBuffer ), attributeStruct.getLength() ), attributeStruct ).toReadOnly().setName( `${ prefix }attributes` ); + + this.storage.transforms = transformsStorage; + this.storage.nodes = bvhNodesStorage; + this.storage.index = indexStorage; + this.storage.attributes = attributesStorage; + this.structs.attributes = attributeStruct; + + this._initFns(); + + function appendBVHData( bvh, geometryOffset, transformInfo, nodeWriteOffset, target, tlas = false ) { + + const targetU16 = new Uint16Array( target ); + const targetU32 = new Uint32Array( target ); + const targetF32 = new Float32Array( target ); + + const result = []; + let tlasOffset = 0; + bvh._roots.forEach( root => { + + const rootBuffer16 = new Uint16Array( root ); + const rootBuffer32 = new Uint32Array( root ); + result.push( nodeWriteOffset ); + for ( let i = 0, l = root.byteLength / BYTES_PER_NODE; i < l; i ++ ) { + + const r32 = i * UINT32_PER_NODE; + const r16 = r32 * 2; + const n32 = nodeWriteOffset * UINT32_PER_NODE; + const n16 = n32 * 2; + + // write bounds + targetF32.set( new Float32Array( root, i * BYTES_PER_NODE, 6 ), n32 ); + + const isLeaf = IS_LEAFNODE_FLAG === rootBuffer16[ r16 + 15 ]; + if ( isLeaf ) { + + if ( tlas ) { + + // 0xFFFF == mesh leaf, 0xFF00 == TLAS leaf + targetU32[ n32 + 6 ] = tlasOffset; + targetU16[ n16 + 15 ] = 0xFF00; + + const count = rootBuffer16[ r16 + 14 ]; + // const offset = rootBuffer32[ r32 + 6 ]; + + // each root is expanded into a separate transform so we need to expand + // the embedded offsets and counts. + let rootsCount = 0; + for ( let o = 0; o < count; o ++ ) { + + const roots = transformInfo[ tlasOffset ].data.bvh._roots.length; + tlasOffset += roots; + rootsCount += roots; + + } + + targetU16[ n16 + 14 ] = rootsCount; + + } else { + + targetU32[ n32 + 6 ] = rootBuffer32[ r32 + 6 ] + geometryOffset; + targetU16[ n16 + 14 ] = rootBuffer16[ r16 + 14 ]; + targetU16[ n16 + 15 ] = IS_LEAFNODE_FLAG; + + } + + } else { + + targetU32[ n32 + 6 ] = rootBuffer32[ r32 + 6 ]; + targetU32[ n32 + 7 ] = rootBuffer32[ r32 + 7 ]; + + } + + nodeWriteOffset ++; + + } + + } ); + + return result; + + } + + function appendIndexData( bvh, range, valueOffset, writeOffset, target ) { + + const { geometry } = bvh; + const { start, count, vertexStart } = range; + if ( bvh.indirect ) { + + const dereferencedIndex = dereferenceIndex( geometry.index, bvh._indirectBuffer ); + for ( let i = 0; i < dereferencedIndex.length; i ++ ) { + + target[ i + writeOffset ] = dereferencedIndex[ i ] - vertexStart + valueOffset; + + } + + } else if ( geometry.index ) { + + for ( let i = 0; i < count; i ++ ) { + + target[ i + writeOffset ] = geometry.index.getX( i + start ) - vertexStart + valueOffset; + + } + + } else { + + for ( let i = 0; i < count; i ++ ) { + + target[ i + writeOffset ] = i + start + valueOffset; + + } + + } + + } + + function appendGeometryData( bvh, range, writeOffset, target ) { + + // if "mesh" is present then it is assumed to be a SkinnedMeshBVH + const { geometry, mesh = null } = bvh; + const { vertexStart, vertexCount } = range; + const attributesBufferF32 = new Float32Array( target ); + attributeStruct.membersLayout.forEach( ( { name }, interleavedOffset ) => { + + // TODO: we should be able to have access to memory layout offsets here via the struct + // API but it's not currently available. + const attr = geometry.attributes[ name ]; + self.getDefaultAttributeValue( name, _def ); + + for ( let i = 0; i < vertexCount; i ++ ) { + + if ( attr ) { + + if ( name === 'position' && mesh ) { + + // TODO: normals and tangents need to be transformed here, as well + mesh.getVertexPosition( i + vertexStart, _vec ); + + } else { + + _vec.fromBufferAttribute( attr, i + vertexStart ); + + } + + switch ( attr.itemSize ) { + + case 1: + _vec.y = _def.y; + _vec.z = _def.z; + _vec.w = _def.w; + break; + case 2: + _vec.z = _def.z; + _vec.w = _def.w; + break; + case 3: + _vec.w = _def.w; + break; + + } + + } else { + + _vec.copy( _def ); + + } + + _vec.toArray( attributesBufferF32, ( writeOffset + i ) * attributeStruct.getLength() + interleavedOffset * 4 ); + + } + + } ); + + } + + } + + _initFns() { + + const { storage, structs, fns, prefix } = this; + + // raycast first hit + fns.raycastFirstHit = this.getShapecastFn( { + name: prefix + 'RaycastFirstHit', + shapeStruct: rayStruct, + resultStruct: intersectionResultStruct, + + boundsOrderFn: wgslTagFn/* wgsl */` + fn getBoundsOrder( ray: ${ rayStruct }, splitAxis: u32, node: ${ bvhNodeStruct } ) -> bool { + + return ray.direction[ splitAxis ] >= 0.0; + + } + `, + intersectsBoundsFn: rayIntersectsBounds, + intersectRangeFn: wgslTagFn/* wgsl */` + fn intersectRange( ray: ${ rayStruct }, offset: u32, count: u32, bestDist: f32 ) -> ${ intersectionResultStruct } { + + var bestHit: ${ intersectionResultStruct }; + bestHit.didHit = false; + bestHit.dist = bestDist; + + for ( var ti = offset; ti < offset + count; ti = ti + 1u ) { + + let i0 = ${ storage.index }[ ti * 3u ]; + let i1 = ${ storage.index }[ ti * 3u + 1u ]; + let i2 = ${ storage.index }[ ti * 3u + 2u ]; + + let a = ${ storage.attributes }[ i0 ].position.xyz; + let b = ${ storage.attributes }[ i1 ].position.xyz; + let c = ${ storage.attributes }[ i2 ].position.xyz; + + var triResult = ${ intersectsTriangle }( ray, a, b, c ); + if ( triResult.didHit && triResult.dist < bestHit.dist ) { + + bestHit = triResult; + bestHit.indices = vec4u( i0, i1, i2, ti ); + + } + + } + + return bestHit; + + } + `, + transformShapeFn: wgslTagFn/* wgsl */` + fn transformRay( ray: ${ rayStruct }, toLocal: mat4x4f ) -> ${ rayStruct } { + + var localRay: Ray; + localRay.origin = ( toLocal * vec4f( ray.origin, 1.0 ) ).xyz; + localRay.direction = ( toLocal * vec4f( ray.direction, 0.0 ) ).xyz; + return localRay; + + } + `, + transformResultFn: wgslTagFn/* wgsl */` + fn transformResult( hit: ptr, toWorld: mat4x4f, toLocal: mat4x4f ) -> void { + + hit.normal = normalize( ( transpose( toLocal ) * vec4f( hit.normal, 0.0 ) ).xyz ); + + } + `, + } ); + + const interpolateBody = structs + .attributes + .membersLayout + .map( ( { name } ) => { + + return `result.${ name } = a0.${ name } * barycoord.x + a1.${ name } * barycoord.y + a2.${ name } * barycoord.z;`; + + } ).join( '\n' ); + fns.sampleTrianglePoint = wgslTagFn/* wgsl */` + // fn + fn ${ prefix }sampleTrianglePoint( barycoord: vec3f, indices: vec3u ) -> ${ structs.attributes } { + + var result: ${ structs.attributes }; + var a0 = ${ storage.attributes }[ indices.x ]; + var a1 = ${ storage.attributes }[ indices.y ]; + var a2 = ${ storage.attributes }[ indices.z ]; + ${ interpolateBody } + return result; + + } + `; + + } + + writeTransformData( info, premultiplyMatrix, writeOffset, targetBuffer ) { + + const { structs } = this; + const transformBufferF32 = new Float32Array( targetBuffer ); + const transformBufferU32 = new Uint32Array( targetBuffer ); + + const { object, instanceId, root, data } = info; + const { bvhNodeOffsets } = data; + + if ( object.isInstancedMesh || object.isBatchedMesh ) { + + object.getMatrixAt( instanceId, _matrix ); + _matrix.premultiply( object.matrixWorld ); + + } else { + + _matrix.copy( object.matrixWorld ); + + } + + _matrix.premultiply( premultiplyMatrix ); + _matrix.toArray( transformBufferF32, writeOffset * structs.transform.getLength() ); + + _matrix.invert(); + _matrix.toArray( transformBufferF32, writeOffset * structs.transform.getLength() + 16 ); + + transformBufferU32[ writeOffset * structs.transform.getLength() + 32 ] = bvhNodeOffsets[ root ]; + + } + + getBVH( object, instanceId, rangeTarget ) { + + let bvh = null; + if ( object.boundsTree ) { + + // TODO + // this is a case where a mesh has morph targets and skinned meshes + + } else if ( object.isBatchedMesh ) { + + const geometryId = object.getGeometryIdAt( instanceId ); + const range = object.getGeometryRangeAt( geometryId ); + Object.assign( rangeTarget, range ); + bvh = object.boundsTrees[ geometryId ]; + + } else { + + const geometry = object.geometry; + rangeTarget.count = geometry.index ? geometry.index.count : geometry.attributes.position.count; + rangeTarget.vertexCount = geometry.attributes.position.count; + bvh = object.geometry.boundsTree; + + } + + if ( ! bvh ) { + + throw new Error( 'BVHComputeData: BVH not found.' ); + + } + + return bvh; + + } + + getDefaultAttributeValue( key, target ) { + + switch ( key ) { + + case 'position': + case 'color': + target.set( 1, 1, 1, 1 ); + break; + + default: + target.set( 0, 0, 0, 0 ); + + } + + return target; + + } + + dispose() { + + // TODO: dispose buffers + + } + +} diff --git a/src/webgpu/lib/ObjectBVH.js b/src/webgpu/lib/ObjectBVH.js new file mode 100644 index 000000000..c0540ed2b --- /dev/null +++ b/src/webgpu/lib/ObjectBVH.js @@ -0,0 +1,615 @@ +import { Box3, BufferGeometry, Matrix4, Mesh, Vector3, Ray, Sphere } from 'three'; +import { BVH, INTERSECTED, NOT_INTERSECTED } from 'three-mesh-bvh'; + +const _geometry = /* @__PURE__ */ new BufferGeometry(); +const _matrix = /* @__PURE__ */ new Matrix4(); +const _inverseMatrix = /* @__PURE__ */ new Matrix4(); +const _box = /* @__PURE__ */ new Box3(); +const _sphere = /* @__PURE__ */ new Sphere(); +const _vec = /* @__PURE__ */ new Vector3(); +const _ray = /* @__PURE__ */ new Ray(); +const _mesh = /* @__PURE__ */ new Mesh(); +const _geometryRange = {}; + +// TODO: account for a "custom" object? Not necessary here? Create a more abstract foundation for this case? +export function objectAcceleratedRaycast( raycaster, intersects ) { + + if ( this.boundsTree ) { + + this.boundsTree.raycast( raycaster, intersects ); + return false; + + } + +} + +export class ObjectBVH extends BVH { + + constructor( root, options = {} ) { + + options = { + precise: false, + includeInstances: true, + matrixWorld: Array.isArray( root ) ? new Matrix4() : root.matrixWorld, + maxLeafSize: 1, + ...options, + }; + + super(); + + // collect all the leaf node objects in the geometries + const objectSet = new Set(); + collectObjects( root, objectSet ); + + // calculate the number of bits required for the primary id, leaving the remainder + // for the instanceId count + const objects = Array.from( objectSet ); + const idBits = Math.ceil( Math.log2( objects.length ) ); + const idMask = constructIdMask( idBits ); + + this.objects = objects; + this.idBits = idBits; + this.idMask = idMask; + this.primitiveBuffer = null; + this.primitiveBufferStride = 1; + + // settings + this.precise = options.precise; + this.includeInstances = options.includeInstances; + this.matrixWorld = options.matrixWorld; + + this.init( options ); + + } + + getObjectFromId( compositeId ) { + + const { idMask, objects } = this; + const id = getObjectId( compositeId, idMask ); + return objects[ id ]; + + } + + getInstanceFromId( compositeId ) { + + const { idMask, idBits } = this; + return getInstanceId( compositeId, idBits, idMask ); + + } + + init( options ) { + + const { objects, idBits } = this; + this.primitiveBuffer = new Uint32Array( this._countPrimitives( objects ) ); + this._fillPrimitiveBuffer( objects, idBits, this.primitiveBuffer ); + + super.init( options ); + + } + + writePrimitiveBounds( i, targetBuffer, writeOffset ) { + + // TODO: it would be best to cache this matrix inversion + const { primitiveBuffer } = this; + _inverseMatrix.copy( this.matrixWorld ).invert(); + + this._getPrimitiveBoundingBox( primitiveBuffer[ i ], _inverseMatrix, _box ); + const { min, max } = _box; + + targetBuffer[ writeOffset + 0 ] = min.x; + targetBuffer[ writeOffset + 1 ] = min.y; + targetBuffer[ writeOffset + 2 ] = min.z; + targetBuffer[ writeOffset + 3 ] = max.x; + targetBuffer[ writeOffset + 4 ] = max.y; + targetBuffer[ writeOffset + 5 ] = max.z; + + } + + getRootRanges() { + + return [ { offset: 0, count: this.primitiveBuffer.length } ]; + + } + + shapecast( callbacks ) { + + return super.shapecast( { + ...callbacks, + + intersectsPrimitive: callbacks.intersectsObject, + scratchPrimitive: null, + iterate: iterateOverObjects, + } ); + + } + + // TODO: this is out of sync with the MeshBVH raycast signature. + raycast( raycaster, intersects = [] ) { + + const { matrixWorld, includeInstances } = this; + const { firstHitOnly } = raycaster; + const localIntersects = []; + + // transform the ray into the local bvh frame + _inverseMatrix.copy( matrixWorld ).invert(); + _ray.copy( raycaster.ray ).applyMatrix4( _inverseMatrix ); + + let closestDistance = Infinity; + let closestHit = null; + + this.shapecast( { + boundsTraverseOrder: box => { + + return box.distanceToPoint( _ray.origin ); + + }, + intersectsBounds: box => { + + if ( firstHitOnly ) { + + if ( ! _ray.intersectBox( box, _vec ) ) { + + return NOT_INTERSECTED; + + } + + // early out if the box is further than the closest raycast + _vec.applyMatrix4( matrixWorld ); + return raycaster.ray.origin.distanceTo( _vec ) < closestDistance ? INTERSECTED : NOT_INTERSECTED; + + } else { + + return _ray.intersectsBox( box ) ? INTERSECTED : NOT_INTERSECTED; + + } + + }, + intersectsObject( object, instanceId ) { + + // skip non visible objects + if ( ! object.visible ) { + + return; + + } + + if ( object.isInstancedMesh && includeInstances ) { + + // raycast the instance + _mesh.geometry = object.geometry; + _mesh.material = object.material; + + object.getMatrixAt( instanceId, _mesh.matrixWorld ); + _mesh.matrixWorld.premultiply( object.matrixWorld ); + _mesh.raycast( raycaster, localIntersects ); + + localIntersects.forEach( hit => { + + hit.object = object; + hit.instanceId = instanceId; + + } ); + + _mesh.material = null; + + } else if ( object.isBatchedMesh && includeInstances ) { + + if ( ! object.getVisibleAt( instanceId ) ) { + + return; + + } + + // extract the geometry & material + const geometryId = object.getGeometryIdAt( instanceId ); + const geometryRange = object.getGeometryRangeAt( geometryId, _geometryRange ); + + _geometry.index = object.geometry.index; + _geometry.attributes.position = object.geometry.attributes.position; + _geometry.setDrawRange( geometryRange.start, geometryRange.count ); + + _mesh.geometry = _geometry; + _mesh.material = object.material; + + // perform a raycast against the proxy mesh + object.getMatrixAt( instanceId, _mesh.matrixWorld ); + _mesh.matrixWorld.premultiply( object.matrixWorld ); + _mesh.raycast( raycaster, localIntersects ); + + // fix up the fields + localIntersects.forEach( hit => { + + hit.object = object; + hit.batchId = instanceId; + + } ); + + _mesh.material = null; + _geometry.index = null; + _geometry.attributes.position = null; + _geometry.setDrawRange( 0, Infinity ); + + } else { + + object.raycast( raycaster, localIntersects ); + + } + + // find the closest hit to track + if ( firstHitOnly ) { + + localIntersects.forEach( hit => { + + if ( hit.distance < closestDistance ) { + + closestDistance = hit.distance; + closestHit = hit; + + } + + } ); + + } else { + + intersects.push( ...localIntersects ); + + } + + }, + } ); + + // save the closest hit only if firstHitOnly = true + if ( firstHitOnly && closestHit ) { + + intersects.push( closestHit ); + + } + + return intersects; + + } + + // get the bounding box of a primitive node accounting for the bvh options + _getPrimitiveBoundingBox( compositeId, inverseMatrixWorld, target ) { + + const { objects, idMask, idBits, precise, includeInstances } = this; + const id = getObjectId( compositeId, idMask ); + const instanceId = getInstanceId( compositeId, idBits, idMask ); + const object = objects[ id ]; + + if ( ! includeInstances && ( object.isInstancedMesh || object.isBatchedMesh ) ) { + + // if we're not using instances then just account for the overall bounds of the BatchedMesh and InstancedMesh + if ( ! object.boundingBox ) { + + object.computeBoundingBox(); + + } + + if ( ! object.boundingSphere ) { + + object.computeBoundingSphere(); + + } + + _matrix + .copy( object.matrixWorld ) + .premultiply( inverseMatrixWorld ); + + _sphere + .copy( object.boundingSphere ) + .applyMatrix4( _matrix ); + + target + .copy( object.boundingBox ) + .applyMatrix4( _matrix ); + + shrinkToSphere( target, _sphere ); + + } else if ( precise ) { + + // calculate precise bounds if necessary by calculating the bounds of all vertices + // in the bvh frame + if ( object.isInstancedMesh ) { + + object + .getMatrixAt( instanceId, _matrix ); + + _matrix + .premultiply( object.matrixWorld ) + .premultiply( inverseMatrixWorld ); + + getPreciseBounds( object.geometry, _matrix, target ); + + } else if ( object.isBatchedMesh ) { + + const geometryId = object.getGeometryIdAt( instanceId ); + const geometryRange = object.getGeometryRangeAt( geometryId, _geometryRange ); + + _geometry.index = object.geometry.index; + _geometry.attributes.position = object.geometry.attributes.position; + _geometry.setDrawRange( geometryRange.start, geometryRange.count ); + + object + .getMatrixAt( instanceId, _matrix ); + + _matrix + .premultiply( object.matrixWorld ) + .premultiply( inverseMatrixWorld ); + + getPreciseBounds( _geometry, _matrix, target ); + + } else { + + _matrix + .copy( object.matrixWorld ) + .premultiply( inverseMatrixWorld ); + + target.setFromObject( object, true ).applyMatrix4( inverseMatrixWorld ); + + } + + } else { + + // otherwise use the fast path of extracting the cached, AABB bounds and transforming them + // into the local BVH frame + if ( object.isInstancedMesh ) { + + if ( ! object.geometry.boundingBox ) { + + object.geometry.computeBoundingBox(); + + } + + if ( ! object.geometry.boundingSphere ) { + + object.geometry.computeBoundingSphere(); + + } + + object + .getMatrixAt( instanceId, _matrix ); + + _matrix + .premultiply( object.matrixWorld ) + .premultiply( inverseMatrixWorld ); + + _sphere + .copy( object.geometry.boundingSphere ) + .applyMatrix4( _matrix ); + + target + .copy( object.geometry.boundingBox ) + .applyMatrix4( _matrix ); + + shrinkToSphere( target, _sphere ); + + } else if ( object.isBatchedMesh ) { + + const geometryId = object.getGeometryIdAt( instanceId ); + + object + .getMatrixAt( instanceId, _matrix ); + + _matrix + .premultiply( object.matrixWorld ) + .premultiply( inverseMatrixWorld ); + + object + .getBoundingSphereAt( geometryId, _sphere ) + .applyMatrix4( _matrix ); + + object + .getBoundingBoxAt( geometryId, target ) + .applyMatrix4( _matrix ); + + shrinkToSphere( target, _sphere ); + + } else { + + target + .setFromObject( object, false ) + .applyMatrix4( inverseMatrixWorld ); + + } + + } + + } + + // counts the total number of primitives required by the objects in given array of objects + _countPrimitives( objects ) { + + const { includeInstances } = this; + let total = 0; + objects.forEach( object => { + + if ( object.isInstancedMesh && includeInstances ) { + + total += object.count; + + } else if ( object.isBatchedMesh && includeInstances ) { + + total += object.instanceCount; + + } else { + + total ++; + + } + + } ); + + return total; + + } + + _fillPrimitiveBuffer( objects, idBits, target ) { + + const { includeInstances } = this; + let index = 0; + objects.forEach( ( object, i ) => { + + if ( object.isInstancedMesh && includeInstances ) { + + const count = object.count; + for ( let c = 0; c < count; c ++ ) { + + target[ index ] = ( c << idBits ) | i; + index ++; + + } + + } else if ( object.isBatchedMesh && includeInstances ) { + + const { instanceCount, maxInstanceCount } = object; + let foundInstances = 0; + let iter = 0; + + while ( foundInstances < instanceCount && iter < maxInstanceCount ) { + + iter ++; + + // TODO: it would be better to have a consistent way of querying whether an + // instance were active + try { + + object.getVisibleAt( iter ); + + target[ index ] = ( iter << idBits ) | i; + foundInstances ++; + index ++; + + } catch { + + // + + } + + } + + } else { + + target[ index ] = i; + index ++; + + } + + } ); + + } + +} + +// id functions +// construct a mask with the given number of bits set to 1 +function constructIdMask( idBits ) { + + let mask = 0; + for ( let i = 0; i < idBits; i ++ ) { + + mask = mask << 1 | 1; + + } + + return mask; + +} + +// extract the primary object id given the provided mask +function getObjectId( id, idMask ) { + + return id & idMask; + +} + +// extract the instance id given the mask and number of bits to shift +function getInstanceId( id, idBits, idMask ) { + + return ( id & ( ~ idMask ) ) >> idBits; + +} + +// traverse the full scene and collect all leaves +function collectObjects( root, objectSet = new Set() ) { + + if ( Array.isArray( root ) ) { + + root.forEach( object => collectObjects( object, objectSet ) ); + + } else { + + root.traverse( child => { + + if ( child.isMesh || child.isLine || child.isPoints ) { + + objectSet.add( child ); + + } + + } ); + + } + +} + +// calculate precise box bounds of the given geometry in the given frame +function getPreciseBounds( geometry, matrix, target ) { + + target.makeEmpty(); + + const drawRange = geometry.drawRange; + const indexAttr = geometry.index; + const posAttr = geometry.attributes.position; + const start = drawRange.start; + const vertCount = indexAttr ? indexAttr.count : posAttr.count; + const count = Math.min( vertCount - start, drawRange.count ); + for ( let i = start, l = start + count; i < l; i ++ ) { + + let vi = i; + if ( indexAttr ) { + + vi = indexAttr.getX( vi ); + + } + + _vec.fromBufferAttribute( posAttr, vi ).applyMatrix4( matrix ); + target.expandByPoint( _vec ); + + } + + return target; + +} + +// iterator helper for raycasting +function iterateOverObjects( offset, count, bvh, callback, contained, depth, /* scratch */ ) { + + const { primitiveBuffer, objects, idMask, idBits } = bvh; + for ( let i = offset, l = count + offset; i < l; i ++ ) { + + const compositeId = primitiveBuffer[ i ]; + const id = getObjectId( compositeId, idMask ); + const instanceId = getInstanceId( compositeId, idBits, idMask ); + const object = objects[ id ]; + if ( callback( object, instanceId, contained, depth ) ) { + + return true; + + } + + } + + return false; + +} + +function shrinkToSphere( box, sphere ) { + + _vec.copy( sphere.center ).addScalar( - sphere.radius ); + box.min.max( _vec ); + + _vec.copy( sphere.center ).addScalar( sphere.radius ); + box.max.min( _vec ); + +} diff --git a/src/webgpu/lib/nodes/NodeProxy.js b/src/webgpu/lib/nodes/NodeProxy.js new file mode 100644 index 000000000..ee3a9962e --- /dev/null +++ b/src/webgpu/lib/nodes/NodeProxy.js @@ -0,0 +1,102 @@ +import { Node } from 'three/webgpu'; + +class ProxyCallNode extends Node { + + static get type() { + + return 'ProxyCallNode'; + + } + + constructor( proxyNode, params ) { + + super(); + this.proxyNode = proxyNode; + this.params = params; + + } + + setup() { + + return this.proxyNode.node.call( ...this.params ); + + } + +} + +export class NodeProxy extends Node { + + static get type() { + + return 'NodeProxy'; + + } + + get node() { + + const { properties, object } = this; + let value = object; + for ( let i = 0, l = properties.length; i < l; i ++ ) { + + value = value[ properties[ i ] ]; + + } + + if ( 'functionNode' in value ) { + + return value.functionNode; + + } else { + + return value; + + } + + } + + constructor( property, object = null ) { + + super(); + this.object = object; + this.property = property; + this.properties = property.split( '.' ); + + } + + // delegate type resolution to the target node + getNodeType( builder ) { + + return this.node.getNodeType( builder ); + + } + + // include the target node's cache key so the proxy invalidates when the target changes + customCacheKey() { + + return this.node.getCacheKey(); + + } + + // return the target node as the output so the builder uses it for analyze/generate + setup( builder ) { + + return this.node; + + } + +} + +export const proxy = ( ...args ) => { + + return new NodeProxy( ...args ); + +}; + +export const proxyFn = ( ...args ) => { + + const nodeProxy = new NodeProxy( ...args ); + const fn = ( ...params ) => new ProxyCallNode( nodeProxy, params ); + fn.functionNode = nodeProxy; + return nodeProxy; + +}; diff --git a/src/webgpu/lib/nodes/WGSLTagFnNode.js b/src/webgpu/lib/nodes/WGSLTagFnNode.js new file mode 100644 index 000000000..85a1468a0 --- /dev/null +++ b/src/webgpu/lib/nodes/WGSLTagFnNode.js @@ -0,0 +1,321 @@ +import { CodeNode, FunctionNode, Node } from 'three/webgpu'; + +// minimal node that outputs a raw WGSL expression verbatim when built +class LiteralExpression extends Node { + + constructor( literal ) { + + super(); + this.literal = literal; + + } + + build() { + + return this.literal; + + } + +} + +// wraps a FunctionNode so that build() returns just the function name +class PropertyRefNode extends Node { + + constructor( node ) { + + super(); + this.node = node; + + } + + build( builder ) { + + return this.node.build( builder, 'property' ); + + } + +} + +// wraps a FunctionCallNode so that build() returns the inline call expression, +// bypassing TempNode's variable wrapping +class InlineCallNode extends Node { + + constructor( node ) { + + super(); + this.node = node; + + } + + build( builder ) { + + return this.node.generate( builder ); + + } + +} + +// returns the node that should be registered as an include for the given arg +function getIncludeNode( arg ) { + + if ( typeof arg === 'function' ) { + + if ( arg.functionNode ) return arg.functionNode; + if ( arg.isStruct ) return arg.layout; + return null; + + } + + if ( arg && arg.isNode ) { + + if ( arg.functionNode ) return arg.functionNode; + if ( arg.isStructLayoutNode || arg.isCodeNode ) return arg; + + } + + return null; + +} + +// extract dependency nodes from template args for include registration +function extractIncludes( args ) { + + const includes = []; + for ( const arg of args ) { + + if ( Array.isArray( arg ) ) { + + for ( const element of arg ) { + + const node = getIncludeNode( element ); + if ( node ) includes.push( node ); + + } + + } else { + + const node = getIncludeNode( arg ); + if ( node ) includes.push( node ); + + } + + } + + return includes; + +} + +// normalize args so generate can resolve them uniformly with build(): +// - callable wrappers > PropertyRefNode (emits just the function name) +// - struct callables > StructTypeNode (emits the type name via build) +// - FunctionCallNodes > InlineCallNode (emits inline call) +function normalizeArgs( args ) { + + return args.map( arg => { + + if ( typeof arg === 'function' && arg.functionNode ) return new PropertyRefNode( arg.functionNode ); + if ( typeof arg === 'function' && arg.isStruct ) return arg.layout; + if ( arg && arg.isNode && arg.functionNode ) return new InlineCallNode( arg ); + return arg; + + } ); + +} + +// interleave static tokens with resolved arg values +function assembleTemplate( tokens, args, builder ) { + + let code = ''; + for ( let i = 0, l = tokens.length; i < l; i ++ ) { + + code += tokens[ i ]; + if ( i < args.length ) { + + const arg = args[ i ]; + if ( Array.isArray( arg ) ) { + + // include array — no text output + + } else if ( typeof arg === 'string' || typeof arg === 'number' ) { + + code += String( arg ); + + } else { + + code += arg.build( builder ); + + } + + } + + } + + return code; + +} + +export class WGSLTagFnNode extends FunctionNode { + + static get type() { + + return 'WGSLFnTagNode'; + + } + + constructor( tokens, args, lang = 'wgsl' ) { + + super( '', extractIncludes( args ), lang ); + + this.tokens = tokens; + this.args = normalizeArgs( args ); + + } + + // assemble the signature from tokens and arg names then parse + getNodeFunction( builder ) { + + const { tokens, args } = this; + const nodeData = builder.getDataFromNode( this ); + let nodeFunction = nodeData.nodeFunction; + if ( nodeFunction === undefined ) { + + // reconstruct the full code with known names for struct args + // and dummy identifiers for everything else + let fullCode = ''; + for ( let i = 0, l = tokens.length; i < l; i ++ ) { + + fullCode += tokens[ i ]; + + if ( i < args.length ) { + + const arg = args[ i ]; + if ( Array.isArray( arg ) ) { + + // include array — no text output + + } else if ( typeof arg === 'string' || typeof arg === 'number' ) { + + // literals + fullCode += String( arg ); + + } else if ( arg.isStructLayoutNode ) { + + // struct type node + fullCode += arg.getNodeType( builder ); + + } else if ( arg.isStruct ) { + + // struct + fullCode += arg.layout.getNodeType( builder ); + + } else { + + fullCode += '_arg' + i; + + } + + } + + } + + // remove comments + fullCode = fullCode.replace( /\/\/.+[\n\r]/g, '' ); + + // parse it so we have the signature defined - we will define the body content after + nodeFunction = builder.parser.parseFunction( fullCode ); + nodeData.nodeFunction = nodeFunction; + + } + + return nodeFunction; + + } + + // get the code for the function + generate( builder, output ) { + + const result = super.generate( builder, output ); + const fullCode = assembleTemplate( this.tokens, this.args, builder ); + + const { type } = this.getNodeFunction( builder ); + const nodeCode = builder.getCodeFromNode( this, type ); + + nodeCode.code = fullCode.replace( /\/\/.+[\n\r]/g, '' ).replace( /->\s*void/, '' ).replace( /\s+/g, ' ' ).trim(); + + return result; + + } + +} + +export class WGSLTagCodeNode extends CodeNode { + + static get type() { + + return 'WGSLTagCodeNode'; + + } + + constructor( tokens, args, lang = 'wgsl' ) { + + super( '', extractIncludes( args ), lang ); + + this.tokens = tokens; + this.args = normalizeArgs( args ); + + } + + generate( builder ) { + + // build includes so dependencies are registered before the parent code block + const includes = this.getIncludes( builder ); + for ( const include of includes ) { + + include.build( builder ); + + } + + return assembleTemplate( this.tokens, this.args, builder ); + + } + +} + +const getFn = functionNode => { + + const fn = ( ...params ) => { + + // wrap string parameter values as raw WGSL expressions so they + // output verbatim as identifiers like local variable names + if ( params.length === 1 && params[ 0 ] && typeof params[ 0 ] === 'object' && ! params[ 0 ].isNode ) { + + const obj = params[ 0 ]; + for ( const key in obj ) { + + if ( typeof obj[ key ] === 'string' ) { + + obj[ key ] = new LiteralExpression( obj[ key ] ); + + } + + } + + } + + return functionNode.call( ...params ); + + }; + + fn.functionNode = functionNode; + return fn; + +}; + +// template tag literal function version of "wgslFn" & "wgsl" to generate +// functions & code snippets respectively +export const wgslTagFn = ( tokens, ...args ) => getFn( new WGSLTagFnNode( tokens, args ) ); +export const wgslTagCode = ( tokens, ...args ) => new WGSLTagCodeNode( tokens, args ); + +// glsl versions +export const glslTagFn = ( tokens, ...args ) => getFn( new WGSLTagFnNode( tokens, args, 'glsl' ) ); +export const glslTagCode = ( tokens, ...args ) => new WGSLTagCodeNode( tokens, args, 'glsl' ); diff --git a/src/webgpu/lib/wgsl/common.wgsl.js b/src/webgpu/lib/wgsl/common.wgsl.js new file mode 100644 index 000000000..e34eb4f20 --- /dev/null +++ b/src/webgpu/lib/wgsl/common.wgsl.js @@ -0,0 +1,68 @@ +import { wgslFn, uint, float } from 'three/tsl'; +import { bvhNodeBoundsStruct, rayStruct } from './structs.wgsl.js'; + +export const constants = { + BVH_STACK_DEPTH: uint( 60 ), + INFINITY: float( 1e20 ), + TRI_INTERSECT_EPSILON: float( 1e-5 ), +}; + +export const ndcToCameraRay = wgslFn( /* wgsl*/` + + fn ndcToCameraRay( ndc: vec2f, inverseModelViewProjection: mat4x4f ) -> Ray { + + // Calculate the ray by picking the points at the near and far plane and deriving the ray + // direction from the two points. This approach works for both orthographic and perspective + // camera projection matrices. + // The returned ray direction is not normalized and extends to the camera far plane. + var homogeneous = vec4f(); + var ray = Ray(); + + homogeneous = inverseModelViewProjection * vec4f( ndc, 0.0, 1.0 ); + ray.origin = homogeneous.xyz / homogeneous.w; + + homogeneous = inverseModelViewProjection * vec4f( ndc, 1.0, 1.0 ); + ray.direction = ( homogeneous.xyz / homogeneous.w ) - ray.origin; + + return ray; + + } +`, [ rayStruct ] ); + +export const rayIntersectsBounds = wgslFn( /* wgsl */` + + fn rayIntersectsBounds( + ray: Ray, + bounds: BVHBoundingBox, + dist: ptr + ) -> bool { + + let boundsMin = vec3( bounds.min[0], bounds.min[1], bounds.min[2] ); + let boundsMax = vec3( bounds.max[0], bounds.max[1], bounds.max[2] ); + + let invDir = 1.0 / ray.direction; + let tMinPlane = ( boundsMin - ray.origin ) * invDir; + let tMaxPlane = ( boundsMax - ray.origin ) * invDir; + + let tMinHit = vec3f( + min( tMinPlane.x, tMaxPlane.x ), + min( tMinPlane.y, tMaxPlane.y ), + min( tMinPlane.z, tMaxPlane.z ) + ); + + let tMaxHit = vec3f( + max( tMinPlane.x, tMaxPlane.x ), + max( tMinPlane.y, tMaxPlane.y ), + max( tMinPlane.z, tMaxPlane.z ) + ); + + let t0 = max( max( tMinHit.x, tMinHit.y ), tMinHit.z ); + let t1 = min( min( tMaxHit.x, tMaxHit.y ), tMaxHit.z ); + + ( *dist ) = max( t0, 0.0 ); + + return t1 >= ( *dist ); + + } + +`, [ rayStruct, bvhNodeBoundsStruct ] ); diff --git a/src/webgpu/lib/wgsl/structs.wgsl.js b/src/webgpu/lib/wgsl/structs.wgsl.js new file mode 100644 index 000000000..040f36be3 --- /dev/null +++ b/src/webgpu/lib/wgsl/structs.wgsl.js @@ -0,0 +1,28 @@ +import { StructTypeNode } from 'three/webgpu'; + +export const rayStruct = new StructTypeNode( { + origin: 'vec3f', + direction: 'vec3f', +}, 'Ray' ); + +export const bvhNodeBoundsStruct = new StructTypeNode( { + min: 'array', + max: 'array', +}, 'BVHBoundingBox' ); +bvhNodeBoundsStruct.getLength = () => 6; + +export const bvhNodeStruct = new StructTypeNode( { + bounds: 'BVHBoundingBox', + rightChildOrTriangleOffset: 'uint', + splitAxisOrTriangleCount: 'uint', +}, 'BVHNode' ); +bvhNodeStruct.getLength = () => bvhNodeBoundsStruct.getLength() + 2; + +export const intersectionResultStruct = new StructTypeNode( { + didHit: 'bool', + indices: 'vec4u', + normal: 'vec3f', + barycoord: 'vec3f', + side: 'float', + dist: 'float', +}, 'IntersectionResult' ); diff --git a/src/webgpu/nodes/PathtracerBVHComputeData.js b/src/webgpu/nodes/PathtracerBVHComputeData.js new file mode 100644 index 000000000..e5ae9b306 --- /dev/null +++ b/src/webgpu/nodes/PathtracerBVHComputeData.js @@ -0,0 +1,131 @@ +import { BufferAttribute, BufferGeometry, StorageBufferAttribute, StructTypeNode } from 'three/webgpu'; +import { BVHComputeData } from '../lib/BVHComputeData.js'; +import { storage } from 'three/tsl'; +import { MeshBVH, SAH } from 'three-mesh-bvh'; + +const transformStruct = new StructTypeNode( { + matrixWorld: 'mat4x4f', + inverseMatrixWorld: 'mat4x4f', + nodeOffset: 'uint', + materialIndex: 'uint', + _alignment0: 'uint', + _alignment1: 'uint', +}, 'TransformStruct' ); + +const materialStruct = new StructTypeNode( { + albedo: 'vec3f', +}, 'MaterialStruct' ); + +// Pathtracer-specific version of the BVHComputeData tht includes material mapping, property structs +export class PathtracerBVHComputeData extends BVHComputeData { + + constructor( bvh, options = {} ) { + + // TODO: once supported we should use the appropriately-sized member sizes + super( bvh, { + attributes: { + position: 'vec4f', + normal: 'vec4f', + uv0: 'vec4f', + }, + ...options, + } ); + + this.structs.transform = transformStruct; + this.structs.material = materialStruct; + this.materials = []; + this.bvhMap = new Map(); + + } + + update() { + + super.update(); + + // build material storage + const { materials, structs, prefix: name } = this; + const materialBuffer = new ArrayBuffer( structs.material.getLength() * materials.length * 4 ); + const materialBufferF32 = new Float32Array( materialBuffer ); + materials.forEach( ( mat, i ) => { + + mat.color.toArray( materialBufferF32, i * structs.material.getLength() ); + + } ); + + const materialStorage = storage( new StorageBufferAttribute( new Uint32Array( materialBuffer ), 8 ), structs.material.name ).toReadOnly().setName( `${ name }materials` ); + this.storage.materials = materialStorage; + + this.bvhMap.clear(); + this.materials.length = 0; + + } + + writeTransformData( info, premultiplyMatrix, writeOffset, targetBuffer ) { + + super.writeTransformData( info, premultiplyMatrix, writeOffset, targetBuffer ); + + // write material data to the transforms + const { materials } = this; + const material = info.object.material; + if ( ! materials.includes( material ) ) { + + materials.push( material ); + + } + + const index = materials.indexOf( material ); + const transformBufferU32 = new Uint32Array( targetBuffer ); + transformBufferU32[ writeOffset * transformStruct.getLength() + 33 ] = index; + + } + + getBVH( object, instanceId, rangeTarget ) { + + const { bvhMap } = this; + const bvh = super.getBVH( object, instanceId, rangeTarget ); + if ( bvhMap.has( bvh ) ) { + + const data = bvhMap.get( bvh ); + Object.assign( rangeTarget, data.range ); + return data.bvh; + + } else if ( bvh.indirect ) { + + // "indirect" bvhs are not supported since they cannot be unpacked in a way tht will allow for coherent material indices + const proxyGeometry = new BufferGeometry(); + proxyGeometry.attributes = bvh.geometry.attributes; + + let array; + if ( bvh.geometry.index ) { + + array = bvh.geometry.index.array.slice( rangeTarget.start, rangeTarget.count + rangeTarget.start ); + + } else { + + const { start, count } = rangeTarget; + array = new Uint32Array( count ); + for ( let i = 0, l = rangeTarget.count; i < l; i ++ ) { + + array[ i ] = start + i; + + } + + } + + proxyGeometry.index = new BufferAttribute( array, 1 ); + rangeTarget.start = 0; + + // TODO: need to handle SkinnedMeshBVH here + const newBVH = new MeshBVH( proxyGeometry, { strategy: SAH, maxLeafSize: 5 } ); + bvhMap.set( bvh, { bvh: newBVH, range: { ...rangeTarget } } ); + return newBVH; + + } else { + + return bvh; + + } + + } + +} diff --git a/src/webgpu/nodes/structs.wgsl.js b/src/webgpu/nodes/structs.wgsl.js index 97cb1ba17..8d872b772 100644 --- a/src/webgpu/nodes/structs.wgsl.js +++ b/src/webgpu/nodes/structs.wgsl.js @@ -2,7 +2,7 @@ import { wgsl } from 'three/tsl'; import { rayStruct } from 'three-mesh-bvh/webgpu'; export const constants = wgsl( /* wgsl */ ` - const PI: f32 = 3.141592653589793; + const PI: f32 = 3.141592653589793; ` ); export const scatterRecordStruct = wgsl( /* wgsl */ `