Skip to content

Commit cdd5653

Browse files
committed
shader-rt: initial
1 parent 1f486e4 commit cdd5653

10 files changed

Lines changed: 260 additions & 7 deletions

File tree

node-graph/gcore/src/raster_types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ mod gpu {
137137

138138
#[derive(Clone, Debug, PartialEq, Hash)]
139139
pub struct GPU {
140-
texture: wgpu::Texture,
140+
pub texture: wgpu::Texture,
141141
}
142142

143143
impl Sealed for Raster<GPU> {}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
use glam::{Vec2, Vec4};
2+
use spirv_std::spirv;
3+
4+
/// WebGPU NDC has x and y in -1.0 ..= 1.0 with y pointing up (like OpenGL) and z in 0.0 ..= 1.0 (unlike OpenGL's default -1.0 ..= 1.0).
5+
/// https://www.w3.org/TR/webgpu/#coordinate-systems
6+
const FULLSCREEN_VERTICES: [Vec2; 3] = [
	// corner at x = -1, y = -1
	Vec2::new(-1.0, -1.0),
	// overshoots the y range so the triangle's edge clears the opposite corner
	Vec2::new(-1.0, 3.0),
	// overshoots the x range for the same reason
	Vec2::new(3.0, -1.0),
];
7+
8+
#[spirv(vertex)]
9+
pub fn fullscreen_vertex(#[spirv(vertex_index)] vertex_index: u32, #[spirv(position)] gl_position: &mut Vec4) {
10+
// broken on edition 2024 branch
11+
// let vertex = unsafe { *FULLSCREEN_VERTICES.index_unchecked(vertex_index as usize) };
12+
let vertex = FULLSCREEN_VERTICES[vertex_index as usize];
13+
*gl_position = Vec4::from((vertex, 0., 1.));
14+
}

node-graph/graster-nodes/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub mod adjust;
44
pub mod adjustments;
55
pub mod blending_nodes;
66
pub mod cubic_spline;
7+
pub mod fullscreen_vertex;
78

89
#[cfg(feature = "std")]
910
pub mod curve;

node-graph/node-macro/src/codegen.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,12 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
295295

296296
let cfg = crate::shader_nodes::modify_cfg(attributes);
297297
let node_input_accessor = generate_node_input_references(parsed, fn_generics, &field_idents, &graphene_core, &identifier, &cfg);
298-
let shader_entry_point = attributes.shader_node.as_ref().map(|n| n.codegen_shader_entry_point(parsed)).unwrap_or(Ok(TokenStream::new()))?;
298+
let (shader_entry_point, shader_gpu_node) = attributes
299+
.shader_node
300+
.as_ref()
301+
.map::<syn::Result<_>, _>(|n| Ok((n.codegen_shader_entry_point(parsed)?, n.codegen_gpu_node(parsed)?)))
302+
.unwrap_or(Ok((TokenStream::new(), TokenStream::new())))?;
303+
299304
Ok(quote! {
300305
/// Underlying implementation for [#struct_name]
301306
#[inline]
@@ -387,6 +392,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
387392
}
388393

389394
#shader_entry_point
395+
396+
#shader_gpu_node
390397
})
391398
}
392399

node-graph/node-macro/src/parsing.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ pub(crate) struct ParsedNodeFn {
3939
pub(crate) description: String,
4040
}
4141

42-
#[derive(Debug, Default)]
42+
#[derive(Debug, Default, Clone)]
4343
pub(crate) struct NodeFnAttributes {
4444
pub(crate) category: Option<LitStr>,
4545
pub(crate) display_name: Option<LitStr>,

node-graph/node-macro/src/shader_nodes/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub fn modify_cfg(attributes: &NodeFnAttributes) -> TokenStream {
1919
}
2020
}
2121

22-
#[derive(Debug, VariantNames)]
22+
#[derive(Debug, Clone, VariantNames)]
2323
pub(crate) enum ShaderNodeType {
2424
PerPixelAdjust(PerPixelAdjust),
2525
}
@@ -36,6 +36,7 @@ impl Parse for ShaderNodeType {
3636

3737
pub trait CodegenShaderEntryPoint {
3838
fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<TokenStream>;
39+
fn codegen_gpu_node(&self, parsed: &ParsedNodeFn) -> syn::Result<TokenStream>;
3940
}
4041

4142
impl CodegenShaderEntryPoint for ShaderNodeType {
@@ -48,4 +49,10 @@ impl CodegenShaderEntryPoint for ShaderNodeType {
4849
ShaderNodeType::PerPixelAdjust(x) => x.codegen_shader_entry_point(parsed),
4950
}
5051
}
52+
53+
fn codegen_gpu_node(&self, parsed: &ParsedNodeFn) -> syn::Result<TokenStream> {
54+
match self {
55+
ShaderNodeType::PerPixelAdjust(x) => x.codegen_gpu_node(parsed),
56+
}
57+
}
5158
}

node-graph/node-macro/src/shader_nodes/per_pixel_adjust.rs

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
use crate::parsing::{ParsedFieldType, ParsedNodeFn, RegularParsedField};
1+
use crate::parsing::{Input, NodeFnAttributes, ParsedField, ParsedFieldType, ParsedNodeFn, RegularParsedField};
22
use crate::shader_nodes::CodegenShaderEntryPoint;
3+
use convert_case::{Case, Casing};
34
use proc_macro2::{Ident, TokenStream};
45
use quote::{ToTokens, format_ident, quote};
56
use std::borrow::Cow;
67
use syn::parse::{Parse, ParseStream};
8+
use syn::{Path, Type, TypePath};
79

8-
#[derive(Debug)]
10+
#[derive(Debug, Clone)]
911
pub struct PerPixelAdjust {}
1012

1113
impl Parse for PerPixelAdjust {
@@ -17,7 +19,7 @@ impl Parse for PerPixelAdjust {
1719
impl CodegenShaderEntryPoint for PerPixelAdjust {
1820
fn codegen_shader_entry_point(&self, parsed: &ParsedNodeFn) -> syn::Result<TokenStream> {
1921
let fn_name = &parsed.fn_name;
20-
let gpu_mod = format_ident!("{}_gpu", parsed.fn_name);
22+
let gpu_mod = format_ident!("{}_gpu_entry_point", parsed.fn_name);
2123
let spirv_image_ty = quote!(Image2d);
2224

2325
// bindings for images start at 1
@@ -96,6 +98,52 @@ impl CodegenShaderEntryPoint for PerPixelAdjust {
9698
}
9799
})
98100
}
101+
102+
fn codegen_gpu_node(&self, parsed: &ParsedNodeFn) -> syn::Result<TokenStream> {
103+
let fn_name = format_ident!("{}_gpu", parsed.fn_name);
104+
let struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal));
105+
let mod_name = fn_name.clone();
106+
107+
let fields = parsed
108+
.fields
109+
.iter()
110+
.map(|f| match &f.ty {
111+
ParsedFieldType::Regular(reg) => Ok(ParsedField {
112+
ty: ParsedFieldType::Regular(RegularParsedField { gpu_image: false, ..reg.clone() }),
113+
..f.clone()
114+
}),
115+
ParsedFieldType::Node { .. } => Err(syn::Error::new_spanned(&f.pat_ident, "PerPixelAdjust shader nodes cannot accept other nodes as generics")),
116+
})
117+
.collect::<syn::Result<_>>()?;
118+
let body = quote! {};
119+
120+
crate::codegen::generate_node_code(&ParsedNodeFn {
121+
vis: parsed.vis.clone(),
122+
attributes: NodeFnAttributes {
123+
shader_node: None,
124+
..parsed.attributes.clone()
125+
},
126+
fn_name,
127+
struct_name,
128+
mod_name,
129+
fn_generics: vec![],
130+
where_clause: None,
131+
input: Input {
132+
pat_ident: parsed.input.pat_ident.clone(),
133+
ty: Type::Path(TypePath {
134+
path: Path::from(format_ident!("Ctx")),
135+
qself: None,
136+
}),
137+
implementations: Default::default(),
138+
},
139+
output_type: parsed.output_type.clone(),
140+
is_async: true,
141+
fields,
142+
body,
143+
crate_name: parsed.crate_name.clone(),
144+
description: "".to_string(),
145+
})
146+
}
99147
}
100148

101149
struct Param<'a> {

node-graph/wgpu-executor/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod context;
2+
pub mod shader_runtime;
23
pub mod texture_upload;
34

45
use anyhow::Result;
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use crate::Context;
2+
use crate::shader_runtime::per_pixel_adjust_runtime::PerPixelAdjustShaderRuntime;
3+
4+
pub mod per_pixel_adjust_runtime;
5+
6+
/// Entry point name of the fullscreen vertex shader, as looked up inside the WGSL shader module.
///
/// NOTE(review): the name repeats `fullscreen_vertex` twice — presumably the module name and the function name
/// joined by the shader compilation pipeline; confirm against the generated WGSL.
pub const FULLSCREEN_VERTEX_SHADER_NAME: &str = "fullscreen_vertexfullscreen_vertex";
7+
8+
pub struct ShaderRuntime {
9+
context: Context,
10+
per_pixel_adjust: PerPixelAdjustShaderRuntime,
11+
}
12+
13+
impl ShaderRuntime {
14+
pub fn new(context: &Context) -> Self {
15+
Self {
16+
context: context.clone(),
17+
per_pixel_adjust: PerPixelAdjustShaderRuntime::new(),
18+
}
19+
}
20+
}
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
use crate::Context;
2+
use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime};
3+
use futures::lock::Mutex;
4+
use graphene_core::raster_types::{GPU, Raster};
5+
use graphene_core::table::{Table, TableRow};
6+
use std::borrow::Cow;
7+
use std::collections::HashMap;
8+
use wgpu::{
9+
BindGroupDescriptor, BindGroupEntry, BindingResource, ColorTargetState, Face, FragmentState, FrontFace, LoadOp, Operations, PolygonMode, PrimitiveState, PrimitiveTopology,
10+
RenderPassColorAttachment, RenderPassDescriptor, RenderPipelineDescriptor, ShaderModuleDescriptor, ShaderSource, StoreOp, TextureDescriptor, TextureDimension, TextureFormat,
11+
TextureViewDescriptor, VertexState,
12+
};
13+
14+
pub struct PerPixelAdjustShaderRuntime {
15+
// TODO: PerPixelAdjustGraphicsPipeline already contains the key as `name`
16+
pipeline_cache: Mutex<HashMap<String, PerPixelAdjustGraphicsPipeline>>,
17+
}
18+
19+
impl PerPixelAdjustShaderRuntime {
20+
pub fn new() -> Self {
21+
Self {
22+
pipeline_cache: Mutex::new(HashMap::new()),
23+
}
24+
}
25+
}
26+
27+
impl ShaderRuntime {
28+
pub async fn run_per_pixel_adjust(&self, input: Table<Raster<GPU>>, info: &PerPixelAdjustInfo<'_>) -> Table<Raster<GPU>> {
29+
let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await;
30+
let pipeline = cache
31+
.entry(info.fragment_shader_name.to_owned())
32+
.or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &info));
33+
pipeline.run(&self.context, input)
34+
}
35+
}
36+
37+
/// Description of a per-pixel-adjust shader to run.
///
/// The fields are public so that code outside this module can construct the value it passes to
/// `ShaderRuntime::run_per_pixel_adjust`.
#[derive(Debug, Clone, Copy)]
pub struct PerPixelAdjustInfo<'a> {
	/// WGSL source of the shader module; it supplies both the vertex and the fragment entry point.
	pub shader_wgsl: &'a str,
	/// Name of the fragment entry point inside `shader_wgsl`; also used as the pipeline cache key.
	pub fragment_shader_name: &'a str,
}
41+
42+
pub struct PerPixelAdjustGraphicsPipeline {
43+
name: String,
44+
pipeline: wgpu::RenderPipeline,
45+
}
46+
47+
impl PerPixelAdjustGraphicsPipeline {
48+
pub fn new(context: &Context, info: &PerPixelAdjustInfo) -> Self {
49+
let device = &context.device;
50+
let name = info.fragment_shader_name.to_owned();
51+
let shader_module = device.create_shader_module(ShaderModuleDescriptor {
52+
label: Some(&format!("PerPixelAdjust {} wgsl shader", name)),
53+
source: ShaderSource::Wgsl(Cow::Borrowed(info.shader_wgsl)),
54+
});
55+
let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor {
56+
label: Some(&format!("PerPixelAdjust {} Pipeline", name)),
57+
layout: None,
58+
vertex: VertexState {
59+
module: &shader_module,
60+
entry_point: Some(FULLSCREEN_VERTEX_SHADER_NAME),
61+
compilation_options: Default::default(),
62+
buffers: &[],
63+
},
64+
primitive: PrimitiveState {
65+
topology: PrimitiveTopology::TriangleList,
66+
strip_index_format: None,
67+
front_face: FrontFace::Ccw,
68+
cull_mode: Some(Face::Back),
69+
unclipped_depth: false,
70+
polygon_mode: PolygonMode::Fill,
71+
conservative: false,
72+
},
73+
depth_stencil: None,
74+
multisample: Default::default(),
75+
fragment: Some(FragmentState {
76+
module: &shader_module,
77+
entry_point: Some(&name),
78+
compilation_options: Default::default(),
79+
targets: &[Some(ColorTargetState {
80+
format: TextureFormat::Rgba32Float,
81+
blend: None,
82+
write_mask: Default::default(),
83+
})],
84+
}),
85+
multiview: None,
86+
cache: None,
87+
});
88+
Self { pipeline, name }
89+
}
90+
91+
pub fn run(&self, context: &Context, input: Table<Raster<GPU>>) -> Table<Raster<GPU>> {
92+
let device = &context.device;
93+
let name = self.name.as_str();
94+
95+
let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("gpu_invert") });
96+
let out = input
97+
.iter()
98+
.map(|instance| {
99+
let tex_in = &instance.element.texture;
100+
let view_in = tex_in.create_view(&TextureViewDescriptor::default());
101+
let format = tex_in.format();
102+
103+
let bind_group = device.create_bind_group(&BindGroupDescriptor {
104+
label: Some(&format!("{name} bind group")),
105+
// `get_bind_group_layout` allocates unnecessary memory, we could create it manually to not do that
106+
layout: &self.pipeline.get_bind_group_layout(0),
107+
entries: &[BindGroupEntry {
108+
binding: 0,
109+
resource: BindingResource::TextureView(&view_in),
110+
}],
111+
});
112+
113+
let tex_out = device.create_texture(&TextureDescriptor {
114+
label: Some(&format!("{name} texture out")),
115+
size: tex_in.size(),
116+
mip_level_count: 1,
117+
sample_count: 1,
118+
dimension: TextureDimension::D2,
119+
format,
120+
usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST | wgpu::TextureUsages::COPY_SRC | wgpu::TextureUsages::RENDER_ATTACHMENT,
121+
view_formats: &[format],
122+
});
123+
124+
let view_out = tex_out.create_view(&TextureViewDescriptor::default());
125+
let mut rp = cmd.begin_render_pass(&RenderPassDescriptor {
126+
label: Some(&format!("{name} render pipeline")),
127+
color_attachments: &[Some(RenderPassColorAttachment {
128+
view: &view_out,
129+
resolve_target: None,
130+
ops: Operations {
131+
// should be dont_care but wgpu doesn't expose that
132+
load: LoadOp::Clear(wgpu::Color::BLACK),
133+
store: StoreOp::Store,
134+
},
135+
})],
136+
depth_stencil_attachment: None,
137+
timestamp_writes: None,
138+
occlusion_query_set: None,
139+
});
140+
rp.set_pipeline(&self.pipeline);
141+
rp.set_bind_group(0, Some(&bind_group), &[]);
142+
rp.draw(0..3, 0..1);
143+
144+
TableRow {
145+
element: Raster::new(GPU { texture: tex_out }),
146+
transform: *instance.transform,
147+
alpha_blending: *instance.alpha_blending,
148+
source_node_id: *instance.source_node_id,
149+
}
150+
})
151+
.collect::<Table<_>>();
152+
context.queue.submit([cmd.finish()]);
153+
out
154+
}
155+
}

0 commit comments

Comments
 (0)