Skip to content

Commit 81640a0

Browse files
committed
difftest runner: make passthrough optional
1 parent f07aa39 commit 81640a0

1 file changed

Lines changed: 29 additions & 12 deletions

File tree

tests/difftests/tests/lib/src/scaffold/shader/rust_gpu_shader.rs

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ use spirv_builder::{ModuleResult, SpirvBuilder};
44
use std::borrow::Cow;
55
use std::path::PathBuf;
66
use std::{env, fs};
7+
use wgpu::ShaderSource;
78

89
/// A compute shader written in Rust compiled with spirv-builder.
910
pub struct RustComputeShader {
1011
pub builder: SpirvBuilder,
12+
pub passthrough: bool,
1113
}
1214

1315
impl RustComputeShader {
@@ -22,13 +24,21 @@ impl RustComputeShader {
2224
.multimodule(false)
2325
.shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit)
2426
.preserve_bindings(true),
27+
passthrough: false,
2528
}
2629
}
2730

2831
pub fn with_capability(mut self, capability: spirv_builder::Capability) -> Self {
2932
self.builder.capabilities.push(capability);
3033
self
3134
}
35+
36+
pub fn passthrough(mut self) -> Self {
37+
Self {
38+
passthrough: true,
39+
..self
40+
}
41+
}
3242
}
3343

3444
impl SpirvShader for RustComputeShader {
@@ -70,18 +80,25 @@ impl WgpuShader for RustComputeShader {
7080
anyhow::bail!("SPIR-V binary length is not a multiple of 4");
7181
}
7282
let shader_words: Vec<u32> = bytemuck::cast_slice(&shader_bytes).to_vec();
73-
let module = unsafe {
74-
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
75-
entry_point: entry_point.clone(),
76-
label: Some("Compute Shader"),
77-
num_workgroups: (0, 0, 0),
78-
runtime_checks: Default::default(),
79-
spirv: Some(Cow::Owned(shader_words)),
80-
dxil: None,
81-
msl: None,
82-
hlsl: None,
83-
glsl: None,
84-
wgsl: None,
83+
let module = if self.passthrough {
84+
unsafe {
85+
device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough {
86+
entry_point: entry_point.clone(),
87+
label: Some("Rust-GPU Compute Shader"),
88+
num_workgroups: (0, 0, 0),
89+
runtime_checks: Default::default(),
90+
spirv: Some(Cow::Owned(shader_words)),
91+
dxil: None,
92+
msl: None,
93+
hlsl: None,
94+
glsl: None,
95+
wgsl: None,
96+
})
97+
}
98+
} else {
99+
device.create_shader_module(wgpu::ShaderModuleDescriptor {
100+
label: Some("Rust-GPU Compute Shader"),
101+
source: ShaderSource::SpirV(Cow::Owned(shader_words)),
85102
})
86103
};
87104
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {

0 commit comments

Comments
 (0)