Skip to content

Commit 1ef16e0

Browse files
committed
Start implementation of context invalidation
* Add inject trait variants * Route Extract / Inject traits to the proto nodes
1 parent f40f292 commit 1ef16e0

12 files changed

Lines changed: 258 additions & 34 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

editor/src/messages/portfolio/document/node_graph/document_node_definitions/document_node_derive.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pub(super) fn post_process_nodes(mut custom: Vec<DocumentNodeDefinition>) -> Vec
4343
fields,
4444
description,
4545
properties,
46+
context_features,
4647
} = metadata;
4748

4849
let Some(implementations) = &node_registry.get(id) else { continue };
@@ -59,10 +60,11 @@ pub(super) fn post_process_nodes(mut custom: Vec<DocumentNodeDefinition>) -> Vec
5960
node_template: NodeTemplate {
6061
document_node: DocumentNode {
6162
inputs,
62-
call_argument: (input_type.clone()),
63+
call_argument: input_type.clone(),
6364
implementation: DocumentNodeImplementation::ProtoNode(id.clone()),
6465
visible: true,
6566
skip_deduplication: false,
67+
context_features: ContextDependencies::from(context_features.as_slice()),
6668
..Default::default()
6769
},
6870
persistent_node_metadata: DocumentNodePersistentMetadata {

node-graph/gcore/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dealloc_nodes = []
1818
graphene-core-shaders = { workspace = true, features = ["std"] }
1919

2020
# Workspace dependencies
21+
bitflags = { workspace = true }
2122
bytemuck = { workspace = true }
2223
node-macro = { workspace = true }
2324
num-traits = { workspace = true }

node-graph/gcore/src/context.rs

Lines changed: 97 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,99 @@ pub trait ExtractIndex {
3131

3232
// Consider returning a slice or something like that
3333
pub trait ExtractVarArgs {
34-
// Call this lifetime 'b so it is less likely to coflict when auto generating the function signature for implementation
3534
fn vararg(&self, index: usize) -> Result<DynRef<'_>, VarArgsResult>;
3635
fn varargs_len(&self) -> Result<usize, VarArgsResult>;
3736
}
37+
3838
// Consider returning a slice or something like that
3939
pub trait CloneVarArgs: ExtractVarArgs {
4040
// fn box_clone(&self) -> Vec<DynBox>;
4141
fn arc_clone(&self) -> Option<Arc<dyn ExtractVarArgs + Send + Sync>>;
4242
}
4343

44+
// Inject* traits for providing context features to downstream nodes
45+
pub trait InjectFootprint {}
46+
pub trait InjectTime {}
47+
pub trait InjectAnimationTime {}
48+
pub trait InjectIndex {}
49+
pub trait InjectVarArgs {}
50+
51+
// Modify* marker traits for context-transparent nodes
52+
pub trait ModifyFootprint: ExtractFootprint + InjectFootprint {}
53+
pub trait ModifyTime: ExtractTime + InjectTime {}
54+
pub trait ModifyAnimationTime: ExtractAnimationTime + InjectAnimationTime {}
55+
pub trait ModifyIndex: ExtractIndex + InjectIndex {}
56+
pub trait ModifyVarArgs: ExtractVarArgs + InjectVarArgs {}
57+
4458
pub trait ExtractAll: ExtractFootprint + ExtractIndex + ExtractTime + ExtractAnimationTime + ExtractVarArgs {}
4559

4660
impl<T: ?Sized + ExtractFootprint + ExtractIndex + ExtractTime + ExtractAnimationTime + ExtractVarArgs> ExtractAll for T {}
4761

62+
impl<T: Ctx> InjectFootprint for T {}
63+
impl<T: Ctx> InjectTime for T {}
64+
impl<T: Ctx> InjectAnimationTime for T {}
65+
impl<T: Ctx> InjectVarArgs for T {}
66+
67+
// Public enum for flexible node macro codegen
68+
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
69+
pub enum ContextFeature {
70+
ExtractFootprint,
71+
ExtractTime,
72+
ExtractAnimationTime,
73+
ExtractIndex,
74+
ExtractVarArgs,
75+
InjectFootprint,
76+
InjectTime,
77+
InjectAnimationTime,
78+
InjectIndex,
79+
InjectVarArgs,
80+
}
81+
82+
// Internal bitflags for fast compiler analysis (only extract features)
83+
use bitflags::bitflags;
84+
bitflags! {
85+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, dyn_any::DynAny, serde::Serialize, serde::Deserialize, Default)]
86+
pub struct ContextFeatures: u32 {
87+
const FOOTPRINT = 1 << 0;
88+
const TIME = 1 << 1;
89+
const ANIMATION_TIME = 1 << 2;
90+
const INDEX = 1 << 3;
91+
const VAR_ARGS = 1 << 4;
92+
}
93+
}
94+
95+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, dyn_any::DynAny, serde::Serialize, serde::Deserialize, Default)]
96+
pub struct ContextDependencies {
97+
extract: ContextFeatures,
98+
inject: ContextFeatures,
99+
}
100+
101+
impl From<&[ContextFeature]> for ContextDependencies {
102+
fn from(features: &[ContextFeature]) -> Self {
103+
let mut extract = ContextFeatures::empty();
104+
let mut inject = ContextFeatures::empty();
105+
for feature in features {
106+
extract |= match feature {
107+
ContextFeature::ExtractFootprint => ContextFeatures::FOOTPRINT,
108+
ContextFeature::ExtractTime => ContextFeatures::TIME,
109+
ContextFeature::ExtractAnimationTime => ContextFeatures::ANIMATION_TIME,
110+
ContextFeature::ExtractIndex => ContextFeatures::INDEX,
111+
ContextFeature::ExtractVarArgs => ContextFeatures::VAR_ARGS,
112+
_ => ContextFeatures::empty(),
113+
};
114+
inject |= match feature {
115+
ContextFeature::InjectFootprint => ContextFeatures::FOOTPRINT,
116+
ContextFeature::InjectTime => ContextFeatures::TIME,
117+
ContextFeature::InjectAnimationTime => ContextFeatures::ANIMATION_TIME,
118+
ContextFeature::InjectIndex => ContextFeatures::INDEX,
119+
ContextFeature::InjectVarArgs => ContextFeatures::VAR_ARGS,
120+
_ => ContextFeatures::empty(),
121+
};
122+
}
123+
Self { extract, inject }
124+
}
125+
}
126+
48127
#[derive(Debug, Clone, PartialEq, Eq)]
49128
pub enum VarArgsResult {
50129
IndexOutOfBounds,
@@ -279,21 +358,29 @@ impl std::hash::Hash for OwnedContextImpl {
279358
impl OwnedContextImpl {
280359
#[track_caller]
281360
pub fn from<T: ExtractAll + CloneVarArgs>(value: T) -> Self {
282-
let footprint = value.try_footprint().copied();
283-
let index = value.try_index();
284-
let time = value.try_time();
285-
let frame_time = value.try_animation_time();
286-
let parent = match value.varargs_len() {
287-
Ok(x) if x > 0 => value.arc_clone(),
288-
_ => None,
289-
};
361+
OwnedContextImpl::from_flags(value, ContextFeatures::all())
362+
}
363+
#[track_caller]
364+
pub fn from_flags<T: ExtractAll + CloneVarArgs>(value: T, bitflags: ContextFeatures) -> Self {
365+
let footprint = bitflags.contains(ContextFeatures::FOOTPRINT).then(|| value.try_footprint().copied()).flatten();
366+
let index = bitflags.contains(ContextFeatures::INDEX).then(|| value.try_index()).flatten();
367+
let time = bitflags.contains(ContextFeatures::TIME).then(|| value.try_time()).flatten();
368+
let animation_time = bitflags.contains(ContextFeatures::ANIMATION_TIME).then(|| value.try_animation_time()).flatten();
369+
let parent = bitflags
370+
.contains(ContextFeatures::VAR_ARGS)
371+
.then(|| match value.varargs_len() {
372+
Ok(x) if x > 0 => value.arc_clone(),
373+
_ => None,
374+
})
375+
.flatten();
376+
290377
OwnedContextImpl {
291378
footprint,
292379
varargs: None,
293380
parent,
294381
index,
295382
real_time: time,
296-
animation_time: frame_time,
383+
animation_time,
297384
}
298385
}
299386
pub const fn empty() -> Self {
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use crate::context::{CloneVarArgs, Context, ContextFeatures, Ctx, ExtractAll};
2+
use crate::gradient::GradientStops;
3+
use crate::raster_types::{CPU, GPU, Raster};
4+
use crate::table::Table;
5+
use crate::vector::Vector;
6+
use crate::{Graphic, OwnedContextImpl};
7+
use core::f64;
8+
use glam::{DAffine2, DVec2};
9+
use graphene_core_shaders::color::Color;
10+
11+
/// Node for filtering context features based on requirements
12+
/// This node is inserted by the compiler to "zero out" unused context parts
13+
#[node_macro::node(category("Internal"))]
14+
async fn context_modification<T>(
15+
ctx: impl Ctx + CloneVarArgs + ExtractAll,
16+
#[implementations(
17+
Context -> (),
18+
Context -> bool,
19+
Context -> u32,
20+
Context -> f32,
21+
Context -> f64,
22+
Context -> DAffine2,
23+
Context -> DVec2,
24+
Context -> Table<Vector>,
25+
Context -> Table<Graphic>,
26+
Context -> Table<Raster<CPU>>,
27+
Context -> Table<Raster<GPU>>,
28+
Context -> Table<Color>,
29+
Context -> Table<GradientStops>,
30+
)]
31+
value: impl Node<Context<'static>, Output = T>,
32+
features_to_keep: ContextFeatures,
33+
) -> T {
34+
let new_context = OwnedContextImpl::from_flags(ctx, features_to_keep);
35+
36+
value.eval(Some(new_context.into())).await
37+
}

node-graph/gcore/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub mod blending_nodes;
77
pub mod bounds;
88
pub mod consts;
99
pub mod context;
10+
pub mod context_modification;
1011
pub mod debug;
1112
pub mod extract_xy;
1213
pub mod generic;

node-graph/gcore/src/registry.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::{Node, NodeIO, NodeIOTypes, ProtoNodeIdentifier, Type, WasmNotSend};
1+
use crate::{ContextFeature, Node, NodeIO, NodeIOTypes, ProtoNodeIdentifier, Type, WasmNotSend};
22
use dyn_any::{DynAny, StaticType};
33
use std::collections::HashMap;
44
use std::marker::PhantomData;
@@ -16,6 +16,7 @@ pub struct NodeMetadata {
1616
pub fields: Vec<FieldMetadata>,
1717
pub description: &'static str,
1818
pub properties: Option<&'static str>,
19+
pub context_features: Vec<ContextFeature>,
1920
}
2021

2122
// Translation struct between macro and definition

node-graph/graph-craft/src/document.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use glam::IVec2;
77
use graphene_core::memo::MemoHashGuard;
88
pub use graphene_core::uuid::NodeId;
99
pub use graphene_core::uuid::generate_uuid;
10-
use graphene_core::{Context, Cow, MemoHash, ProtoNodeIdentifier, Type};
10+
use graphene_core::{Context, ContextDependencies, Cow, MemoHash, ProtoNodeIdentifier, Type};
1111
use log::Metadata;
1212
use rustc_hash::FxHashMap;
1313
use std::collections::HashMap;
@@ -129,6 +129,9 @@ pub struct DocumentNode {
129129
/// The path to this node and its inputs and outputs as of when [`NodeNetwork::generate_node_paths`] was called.
130130
#[serde(skip)]
131131
pub original_location: OriginalLocation,
132+
// List of Extract and Inject annotations for the Context
133+
#[serde(default)]
134+
pub context_features: ContextDependencies,
132135
}
133136

134137
/// Represents the original location of a node input/output when [`NodeNetwork::generate_node_paths`] was called, allowing the types and errors to be derived.
@@ -163,6 +166,7 @@ impl Default for DocumentNode {
163166
visible: true,
164167
skip_deduplication: Default::default(),
165168
original_location: OriginalLocation::default(),
169+
context_features: Default::default(),
166170
}
167171
}
168172
}
@@ -231,6 +235,7 @@ impl DocumentNode {
231235
construction_args: args,
232236
original_location: self.original_location,
233237
skip_deduplication: self.skip_deduplication,
238+
context_features: self.context_features,
234239
}
235240
}
236241
}

node-graph/graph-craft/src/proto.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ pub struct ProtoNode {
132132
pub identifier: ProtoNodeIdentifier,
133133
pub original_location: OriginalLocation,
134134
pub skip_deduplication: bool,
135+
pub(crate) context_features: ContextDependencies,
135136
}
136137

137138
impl Default for ProtoNode {
@@ -142,6 +143,7 @@ impl Default for ProtoNode {
142143
call_argument: concrete!(()),
143144
original_location: OriginalLocation::default(),
144145
skip_deduplication: false,
146+
context_features: Default::default(),
145147
}
146148
}
147149
}
@@ -181,6 +183,7 @@ impl ProtoNode {
181183
..Default::default()
182184
},
183185
skip_deduplication: false,
186+
context_features: Default::default(),
184187
}
185188
}
186189

@@ -290,11 +293,63 @@ impl ProtoNetwork {
290293
(inwards_edges, id_map)
291294
}
292295

296+
/// Inserts context nullification nodes to optimize caching
297+
/// This analysis is performed after topological sorting to ensure proper dependency tracking
298+
pub fn insert_context_nullification_nodes(&mut self) -> Result<(), String> {
299+
// TODO: Implement full context flow analysis with:
300+
// 1. DFS traversal to track context feature requirements
301+
// 2. Branch convergence analysis
302+
// 3. Post-injection nullification
303+
// 4. Insert ContextModificationNode instances where beneficial
304+
305+
Ok(())
306+
}
307+
293308
/// Inserts a [`structural::ComposeNode`] for each node that has a [`ProtoNodeInput::Node`]. The compose node evaluates the first node, and then sends the result into the second node.
294309
pub fn resolve_inputs(&mut self) -> Result<(), String> {
295310
// Perform topological sort once
296311
self.reorder_ids()?;
312+
// Insert context nullification nodes after topological sort
313+
self.insert_context_nullification_nodes()?;
314+
315+
// Collect outward edges once
316+
let outwards_edges = self.collect_outwards_edges();
317+
318+
// // Iterate over nodes in topological order
319+
// for node_id in 0..=max_id {
320+
// let node_id = NodeId(node_id);
297321

322+
// let (_, node) = &mut self.nodes[node_id.0 as usize];
323+
324+
// if let ProtoNodeInput::Node(input_node_id) = node.input {
325+
// // Create a new node that composes the current node and its input node
326+
// let compose_node_id = NodeId(self.nodes.len() as u64);
327+
328+
// let (_, input_node_id_proto) = &self.nodes[input_node_id.0 as usize];
329+
330+
// let input = input_node_id_proto.input.clone();
331+
332+
// let mut path = input_node_id_proto.original_location.path.clone();
333+
// if let Some(path) = &mut path {
334+
// path.push(node_id);
335+
// }
336+
337+
// self.nodes.push((
338+
// compose_node_id,
339+
// ProtoNode {
340+
// identifier: ProtoNodeIdentifier::new("graphene_core::structural::ComposeNode"),
341+
// construction_args: ConstructionArgs::Nodes(vec![(input_node_id, false), (node_id, true)]),
342+
// call_argument,
343+
// original_location: OriginalLocation { path, ..Default::default() },
344+
// skip_deduplication: false,
345+
// context_features: Default::default(),
346+
// },
347+
// ));
348+
349+
// self.replace_node_id(&outwards_edges, node_id, compose_node_id);
350+
// }
351+
// }
352+
self.reorder_ids()?;
298353
Ok(())
299354
}
300355

node-graph/node-macro/src/codegen.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
4141
let struct_generics: Vec<Ident> = fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();
4242
let input_ident = &input.pat_ident;
4343

44+
let context_features = &input.context_features;
45+
4446
let field_idents: Vec<_> = fields.iter().map(|f| &f.pat_ident).collect();
4547
let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect();
4648

@@ -242,7 +244,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
242244
#name: #graphene_core::Node<'n, #input_type, Output = #fut_ident > + #graphene_core::WasmNotSync
243245
)
244246
}
245-
(ParsedFieldType::Node { .. }, false) => unreachable!(),
247+
(ParsedFieldType::Node { .. }, false) => unreachable!("Found node which takes an impl Node<> input but is not async"),
246248
});
247249
}
248250
let where_clause = where_clause.clone().unwrap_or(WhereClause {
@@ -327,7 +329,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
327329
mod #mod_name {
328330
use super::*;
329331
use #graphene_core as gcore;
330-
use gcore::{Node, NodeIOTypes, concrete, fn_type, fn_type_fut, future, ProtoNodeIdentifier, WasmNotSync, NodeIO};
332+
use gcore::{Node, NodeIOTypes, concrete, fn_type, fn_type_fut, future, ProtoNodeIdentifier, WasmNotSync, NodeIO, ContextFeature};
331333
use gcore::value::ClonedNode;
332334
use gcore::ops::TypeNode;
333335
use gcore::registry::{NodeMetadata, FieldMetadata, NODE_REGISTRY, NODE_METADATA, DynAnyNode, DowncastBothNode, DynFuture, TypeErasedBox, PanicNode, RegistryValueSource, RegistryWidgetOverride};
@@ -362,6 +364,7 @@ pub(crate) fn generate_node_code(parsed: &ParsedNodeFn) -> syn::Result<TokenStre
362364
category: #category,
363365
description: #description,
364366
properties: #properties,
367+
context_features: vec![#(ContextFeature::#context_features,)*],
365368
fields: vec![
366369
#(
367370
FieldMetadata {

0 commit comments

Comments
 (0)