Skip to content

Commit 403f139

Browse files
committed
Add generated Visit trait for AST node traversal
Enable walking the AST by generating a Visit trait with per-node visitor methods. It uses double dispatch to route each node type to its corresponding visitor method. This avoids consumers needing to manually match on Node variants and allows overriding specific visits while inheriting default behavior for others.
1 parent 3e36dd0 commit 403f139

1 file changed

Lines changed: 70 additions & 8 deletions

File tree

rust/ruby-rbs/build.rs

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ struct Node {
2727
fields: Option<Vec<NodeField>>,
2828
}
2929

30+
impl Node {
31+
fn variant_name(&self) -> &str {
32+
self.rust_name
33+
.strip_suffix("Node")
34+
.unwrap_or(&self.rust_name)
35+
}
36+
}
37+
3038
fn main() -> Result<(), Box<dyn Error>> {
3139
let config_path = Path::new(env!("CARGO_MANIFEST_DIR"))
3240
.join("../../config.yml")
@@ -62,11 +70,12 @@ fn main() -> Result<(), Box<dyn Error>> {
6270
enum CIdentifier {
6371
Type, // foo_bar_t
6472
Constant, // FOO_BAR
73+
Method, // visit_foo_bar
6574
}
6675

6776
fn convert_name(name: &str, identifier: CIdentifier) -> String {
6877
let type_name = name.replace("::", "_");
69-
let lowercase = matches!(identifier, CIdentifier::Type);
78+
let lowercase = matches!(identifier, CIdentifier::Type | CIdentifier::Method);
7079
let mut out = String::new();
7180
let mut prev_is_lower = false;
7281

@@ -94,12 +103,67 @@ fn convert_name(name: &str, identifier: CIdentifier) -> String {
94103
}
95104
}
96105

97-
if lowercase {
106+
if matches!(identifier, CIdentifier::Type) {
98107
out.push_str("_t");
99108
}
100109
out
101110
}
102111

112+
fn write_visit_trait(file: &mut File, config: &Config) -> Result<(), Box<dyn std::error::Error>> {
113+
writeln!(file, "/// A trait for traversing the AST using a visitor")?;
114+
writeln!(file, "pub trait Visit {{")?;
115+
writeln!(
116+
file,
117+
" /// Visit any node of the AST. Generally used to continue traversal"
118+
)?;
119+
writeln!(file, " fn visit(&mut self, node: &Node) {{")?;
120+
writeln!(file, " match node {{")?;
121+
122+
for node in &config.nodes {
123+
let node_variant_name = node.variant_name();
124+
let method_name = convert_name(node_variant_name, CIdentifier::Method);
125+
126+
writeln!(file, " Node::{}(it) => {{", node_variant_name)?;
127+
writeln!(file, " self.visit_{}_node(it);", method_name,)?;
128+
writeln!(file, " }}")?;
129+
}
130+
131+
writeln!(file, " }}")?;
132+
writeln!(file, " }}")?;
133+
134+
for node in &config.nodes {
135+
let node_variant_name = node.variant_name();
136+
let method_name = convert_name(node_variant_name, CIdentifier::Method);
137+
138+
writeln!(file)?;
139+
writeln!(
140+
file,
141+
" fn visit_{}_node(&mut self, node: &{}Node) {{",
142+
method_name, node_variant_name
143+
)?;
144+
writeln!(file, " visit_{}_node(self, node);", method_name)?;
145+
writeln!(file, " }}")?;
146+
}
147+
writeln!(file, "}}")?;
148+
writeln!(file)?;
149+
150+
for node in &config.nodes {
151+
let node_variant_name = node.variant_name();
152+
let method_name = convert_name(node_variant_name, CIdentifier::Method);
153+
154+
writeln!(file, "#[allow(unused_variables)]")?; // TODO: Remove this once all nodes that need visitor are implemented
155+
writeln!(
156+
file,
157+
"pub fn visit_{}_node<V: Visit + ?Sized>(visitor: &mut V, node: &{}Node) {{",
158+
method_name, node_variant_name
159+
)?;
160+
writeln!(file, "}}")?;
161+
writeln!(file)?;
162+
}
163+
164+
Ok(())
165+
}
166+
103167
fn generate(config: &Config) -> Result<(), Box<dyn Error>> {
104168
let out_dir = env::var("OUT_DIR")?;
105169
let dest_path = Path::new(&out_dir).join("bindings.rs");
@@ -291,18 +355,13 @@ fn generate(config: &Config) -> Result<(), Box<dyn Error>> {
291355
)?;
292356
writeln!(file, " match unsafe {{ (*node).type_ }} {{")?;
293357
for node in &config.nodes {
294-
let variant_name = node
295-
.rust_name
296-
.strip_suffix("Node")
297-
.unwrap_or(&node.rust_name);
298-
299358
let enum_name = convert_name(&node.name, CIdentifier::Constant);
300359

301360
writeln!(
302361
file,
303362
" rbs_node_type::{} => Self::{}({} {{ parser, pointer: node.cast::<{}>() }}),",
304363
enum_name,
305-
variant_name,
364+
node.variant_name(),
306365
node.rust_name,
307366
convert_name(&node.name, CIdentifier::Type)
308367
)?;
@@ -314,6 +373,9 @@ fn generate(config: &Config) -> Result<(), Box<dyn Error>> {
314373
writeln!(file, " }}")?;
315374
writeln!(file, " }}")?;
316375
writeln!(file, "}}")?;
376+
writeln!(file)?;
377+
378+
write_visit_trait(&mut file, config)?;
317379

318380
Ok(())
319381
}

0 commit comments

Comments
 (0)