|
| 1 | +use crate::Context; |
| 2 | +use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime}; |
| 3 | +use futures::lock::Mutex; |
| 4 | +use graphene_core::raster_types::{GPU, Raster}; |
| 5 | +use graphene_core::table::{Table, TableRow}; |
| 6 | +use std::borrow::Cow; |
| 7 | +use std::collections::HashMap; |
| 8 | +use wgpu::{ |
| 9 | + BindGroupDescriptor, BindGroupEntry, BindingResource, ColorTargetState, Face, FragmentState, FrontFace, LoadOp, Operations, PolygonMode, PrimitiveState, PrimitiveTopology, |
| 10 | + RenderPassColorAttachment, RenderPassDescriptor, RenderPipelineDescriptor, ShaderModuleDescriptor, ShaderSource, StoreOp, TextureDescriptor, TextureDimension, TextureFormat, |
| 11 | + TextureViewDescriptor, VertexState, |
| 12 | +}; |
| 13 | + |
| 14 | +pub struct PerPixelAdjustShaderRuntime { |
| 15 | + // TODO: PerPixelAdjustGraphicsPipeline already contains the key as `name` |
| 16 | + pipeline_cache: Mutex<HashMap<String, PerPixelAdjustGraphicsPipeline>>, |
| 17 | +} |
| 18 | + |
| 19 | +impl PerPixelAdjustShaderRuntime { |
| 20 | + pub fn new() -> Self { |
| 21 | + Self { |
| 22 | + pipeline_cache: Mutex::new(HashMap::new()), |
| 23 | + } |
| 24 | + } |
| 25 | +} |
| 26 | + |
| 27 | +impl ShaderRuntime { |
| 28 | + pub async fn run_per_pixel_adjust(&self, input: Table<Raster<GPU>>, info: &PerPixelAdjustInfo<'_>) -> Table<Raster<GPU>> { |
| 29 | + let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await; |
| 30 | + let pipeline = cache |
| 31 | + .entry(info.fragment_shader_name.to_owned()) |
| 32 | + .or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &info)); |
| 33 | + pipeline.run(&self.context, input) |
| 34 | + } |
| 35 | +} |
| 36 | + |
| 37 | +pub struct PerPixelAdjustInfo<'a> { |
| 38 | + shader_wgsl: &'a str, |
| 39 | + fragment_shader_name: &'a str, |
| 40 | +} |
| 41 | + |
| 42 | +pub struct PerPixelAdjustGraphicsPipeline { |
| 43 | + name: String, |
| 44 | + pipeline: wgpu::RenderPipeline, |
| 45 | +} |
| 46 | + |
| 47 | +impl PerPixelAdjustGraphicsPipeline { |
| 48 | + pub fn new(context: &Context, info: &PerPixelAdjustInfo) -> Self { |
| 49 | + let device = &context.device; |
| 50 | + let name = info.fragment_shader_name.to_owned(); |
| 51 | + let shader_module = device.create_shader_module(ShaderModuleDescriptor { |
| 52 | + label: Some(&format!("PerPixelAdjust {} wgsl shader", name)), |
| 53 | + source: ShaderSource::Wgsl(Cow::Borrowed(info.shader_wgsl)), |
| 54 | + }); |
| 55 | + let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor { |
| 56 | + label: Some(&format!("PerPixelAdjust {} Pipeline", name)), |
| 57 | + layout: None, |
| 58 | + vertex: VertexState { |
| 59 | + module: &shader_module, |
| 60 | + entry_point: Some(FULLSCREEN_VERTEX_SHADER_NAME), |
| 61 | + compilation_options: Default::default(), |
| 62 | + buffers: &[], |
| 63 | + }, |
| 64 | + primitive: PrimitiveState { |
| 65 | + topology: PrimitiveTopology::TriangleList, |
| 66 | + strip_index_format: None, |
| 67 | + front_face: FrontFace::Ccw, |
| 68 | + cull_mode: Some(Face::Back), |
| 69 | + unclipped_depth: false, |
| 70 | + polygon_mode: PolygonMode::Fill, |
| 71 | + conservative: false, |
| 72 | + }, |
| 73 | + depth_stencil: None, |
| 74 | + multisample: Default::default(), |
| 75 | + fragment: Some(FragmentState { |
| 76 | + module: &shader_module, |
| 77 | + entry_point: Some(&name), |
| 78 | + compilation_options: Default::default(), |
| 79 | + targets: &[Some(ColorTargetState { |
| 80 | + format: TextureFormat::Rgba32Float, |
| 81 | + blend: None, |
| 82 | + write_mask: Default::default(), |
| 83 | + })], |
| 84 | + }), |
| 85 | + multiview: None, |
| 86 | + cache: None, |
| 87 | + }); |
| 88 | + Self { pipeline, name } |
| 89 | + } |
| 90 | + |
| 91 | + pub fn run(&self, context: &Context, input: Table<Raster<GPU>>) -> Table<Raster<GPU>> { |
| 92 | + let device = &context.device; |
| 93 | + let name = self.name.as_str(); |
| 94 | + |
| 95 | + let mut cmd = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("gpu_invert") }); |
| 96 | + let out = input |
| 97 | + .iter() |
| 98 | + .map(|instance| { |
| 99 | + let tex_in = &instance.element.texture; |
| 100 | + let view_in = tex_in.create_view(&TextureViewDescriptor::default()); |
| 101 | + let format = tex_in.format(); |
| 102 | + |
| 103 | + let bind_group = device.create_bind_group(&BindGroupDescriptor { |
| 104 | + label: Some(&format!("{name} bind group")), |
| 105 | + // `get_bind_group_layout` allocates unnecessary memory, we could create it manually to not do that |
| 106 | + layout: &self.pipeline.get_bind_group_layout(0), |
| 107 | + entries: &[BindGroupEntry { |
| 108 | + binding: 0, |
| 109 | + resource: BindingResource::TextureView(&view_in), |
| 110 | + }], |
| 111 | + }); |
| 112 | + |
| 113 | + let tex_out = device.create_texture(&TextureDescriptor { |
| 114 | + label: Some(&format!("{name} texture out")), |
| 115 | + size: tex_in.size(), |
| 116 | + mip_level_count: 1, |
| 117 | + sample_count: 1, |
| 118 | + dimension: TextureDimension::D2, |
| 119 | + format, |
| 120 | + usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST | wgpu::TextureUsages::COPY_SRC | wgpu::TextureUsages::RENDER_ATTACHMENT, |
| 121 | + view_formats: &[format], |
| 122 | + }); |
| 123 | + |
| 124 | + let view_out = tex_out.create_view(&TextureViewDescriptor::default()); |
| 125 | + let mut rp = cmd.begin_render_pass(&RenderPassDescriptor { |
| 126 | + label: Some(&format!("{name} render pipeline")), |
| 127 | + color_attachments: &[Some(RenderPassColorAttachment { |
| 128 | + view: &view_out, |
| 129 | + resolve_target: None, |
| 130 | + ops: Operations { |
| 131 | + // should be dont_care but wgpu doesn't expose that |
| 132 | + load: LoadOp::Clear(wgpu::Color::BLACK), |
| 133 | + store: StoreOp::Store, |
| 134 | + }, |
| 135 | + })], |
| 136 | + depth_stencil_attachment: None, |
| 137 | + timestamp_writes: None, |
| 138 | + occlusion_query_set: None, |
| 139 | + }); |
| 140 | + rp.set_pipeline(&self.pipeline); |
| 141 | + rp.set_bind_group(0, Some(&bind_group), &[]); |
| 142 | + rp.draw(0..3, 0..1); |
| 143 | + |
| 144 | + TableRow { |
| 145 | + element: Raster::new(GPU { texture: tex_out }), |
| 146 | + transform: *instance.transform, |
| 147 | + alpha_blending: *instance.alpha_blending, |
| 148 | + source_node_id: *instance.source_node_id, |
| 149 | + } |
| 150 | + }) |
| 151 | + .collect::<Table<_>>(); |
| 152 | + context.queue.submit([cmd.finish()]); |
| 153 | + out |
| 154 | + } |
| 155 | +} |
0 commit comments