@@ -18,8 +18,15 @@ use super::backend::{self, ComputeBackend};
1818pub type BufferConfig = backend:: BufferConfig ;
1919pub type BufferUsage = backend:: BufferUsage ;
2020
21- /// Trait that creates a shader module and provides its entry point.
22- pub trait ComputeShader {
21+ /// Trait for shaders that can provide SPIRV bytes.
22+ pub trait SpirvShader {
23+ /// Returns the SPIRV bytes and entry point name.
24+ fn spirv_bytes ( & self ) -> anyhow:: Result < ( Vec < u8 > , String ) > ;
25+ }
26+
27+ /// Trait for shaders that can create wgpu modules.
28+ pub trait WgpuShader {
29+ /// Creates a wgpu shader module.
2330 fn create_module (
2431 & self ,
2532 device : & wgpu:: Device ,
@@ -29,25 +36,46 @@ pub trait ComputeShader {
2936/// A compute shader written in Rust compiled with spirv-builder.
3037pub struct RustComputeShader {
3138 pub path : PathBuf ,
39+ pub target : String ,
40+ pub capabilities : Vec < spirv_builder:: Capability > ,
3241}
3342
3443impl RustComputeShader {
3544 pub fn new < P : Into < PathBuf > > ( path : P ) -> Self {
36- Self { path : path. into ( ) }
45+ Self {
46+ path : path. into ( ) ,
47+ target : "spirv-unknown-vulkan1.1" . to_string ( ) ,
48+ capabilities : Vec :: new ( ) ,
49+ }
50+ }
51+
52+ pub fn with_target < P : Into < PathBuf > > ( path : P , target : impl Into < String > ) -> Self {
53+ Self {
54+ path : path. into ( ) ,
55+ target : target. into ( ) ,
56+ capabilities : Vec :: new ( ) ,
57+ }
58+ }
59+
60+ pub fn with_capability ( mut self , capability : spirv_builder:: Capability ) -> Self {
61+ self . capabilities . push ( capability) ;
62+ self
3763 }
3864}
3965
40- impl ComputeShader for RustComputeShader {
41- fn create_module (
42- & self ,
43- device : & wgpu:: Device ,
44- ) -> anyhow:: Result < ( wgpu:: ShaderModule , Option < String > ) > {
45- let builder = SpirvBuilder :: new ( & self . path , "spirv-unknown-vulkan1.1" )
66+ impl SpirvShader for RustComputeShader {
67+ fn spirv_bytes ( & self ) -> anyhow:: Result < ( Vec < u8 > , String ) > {
68+ let mut builder = SpirvBuilder :: new ( & self . path , & self . target )
4669 . print_metadata ( spirv_builder:: MetadataPrintout :: None )
4770 . release ( true )
4871 . multimodule ( false )
4972 . shader_panic_strategy ( spirv_builder:: ShaderPanicStrategy :: SilentExit )
5073 . preserve_bindings ( true ) ;
74+
75+ for capability in & self . capabilities {
76+ builder = builder. capability ( * capability) ;
77+ }
78+
5179 let artifact = builder. build ( ) . context ( "SpirvBuilder::build() failed" ) ?;
5280
5381 if artifact. entry_points . len ( ) != 1 {
@@ -66,6 +94,17 @@ impl ComputeShader for RustComputeShader {
6694 }
6795 } ;
6896
97+ Ok ( ( shader_bytes, entry_point) )
98+ }
99+ }
100+
101+ impl WgpuShader for RustComputeShader {
102+ fn create_module (
103+ & self ,
104+ device : & wgpu:: Device ,
105+ ) -> anyhow:: Result < ( wgpu:: ShaderModule , Option < String > ) > {
106+ let ( shader_bytes, entry_point) = self . spirv_bytes ( ) ?;
107+
69108 if shader_bytes. len ( ) % 4 != 0 {
70109 anyhow:: bail!( "SPIR-V binary length is not a multiple of 4" ) ;
71110 }
@@ -93,7 +132,7 @@ impl WgslComputeShader {
93132 }
94133}
95134
96- impl ComputeShader for WgslComputeShader {
135+ impl WgpuShader for WgslComputeShader {
97136 fn create_module (
98137 & self ,
99138 device : & wgpu:: Device ,
@@ -133,7 +172,7 @@ pub struct WgpuComputeTestPushConstants<S> {
133172
134173impl < S > WgpuComputeTest < S >
135174where
136- S : ComputeShader ,
175+ S : WgpuShader ,
137176{
138177 pub fn new ( shader : S , dispatch : [ u32 ; 3 ] , output_bytes : u64 ) -> Self {
139178 Self {
@@ -544,7 +583,7 @@ impl Default for RustComputeShader {
544583
545584impl < S > WgpuComputeTestMultiBuffer < S >
546585where
547- S : ComputeShader ,
586+ S : WgpuShader ,
548587{
549588 pub fn new ( shader : S , dispatch : [ u32 ; 3 ] , buffers : Vec < BufferConfig > ) -> Self {
550589 Self {
@@ -714,7 +753,7 @@ where
714753
715754impl < S > WgpuComputeTestPushConstants < S >
716755where
717- S : ComputeShader ,
756+ S : WgpuShader ,
718757{
719758 pub fn new (
720759 shader : S ,
0 commit comments