Skip to content

Commit 7e2d79d

Browse files
committed
Generate child node traversal in visitor functions (#72)
The Visit trait added in #69 provided the scaffolding for AST traversal, but the visitor functions were empty stubs that didn't recurse into children nodes. Without this, the visitor pattern is incomplete as we'd have to manually write traversal logic every time we want to walk the tree. This commit adds the generation of visitor functions for child node traversal. We handle four field types: - `rbs_node`: single child node - `rbs_node_list`: list of child nodes - `rbs_hash`: key-value pairs of nodes - Wrapper types (`rbs_type_name`, `rbs_namespace`, etc): each with its own visitor method Each case handles optional fields to safely skip NULL pointers
1 parent cf511db commit 7e2d79d

2 files changed

Lines changed: 235 additions & 1 deletion

File tree

rust/ruby-rbs/build.rs

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,135 @@ fn write_visit_trait(file: &mut File, config: &Config) -> Result<(), Box<dyn std
185185
writeln!(file, "}}")?;
186186
writeln!(file)?;
187187

188+
// Map C field types (e.g. `rbs_type_name`) to the corresponding
189+
// visitor method name (e.g. `type_name` -> `visit_type_name_node`).
190+
let visitor_method_names: std::collections::HashMap<String, String> = config
191+
.nodes
192+
.iter()
193+
.map(|node| {
194+
let c_type = convert_name(&node.name, CIdentifier::Type);
195+
let c_type = c_type.strip_suffix("_t").unwrap_or(&c_type).to_string();
196+
let method = convert_name(node.variant_name(), CIdentifier::Method);
197+
(c_type, method)
198+
})
199+
.collect();
200+
201+
let is_visitable = |c_type: &str| -> bool {
202+
matches!(c_type, "rbs_node" | "rbs_node_list" | "rbs_hash")
203+
|| visitor_method_names.contains_key(c_type)
204+
};
205+
188206
for node in &config.nodes {
189207
let node_variant_name = node.variant_name();
190208
let method_name = convert_name(node_variant_name, CIdentifier::Method);
191209

192-
writeln!(file, "#[allow(unused_variables)]")?; // TODO: Remove this once all nodes that need visitor are implemented
210+
let has_visitable_fields = node
211+
.fields
212+
.iter()
213+
.flatten()
214+
.any(|field| is_visitable(&field.c_type));
215+
216+
if !has_visitable_fields {
217+
// If there's nothing to visit in this node, write the empty method with
218+
// underscored parameters, and skip to the next iteration
219+
writeln!(
220+
file,
221+
"pub fn visit_{method_name}_node<V: Visit + ?Sized>(_visitor: &mut V, _node: &{node_variant_name}Node) {{}}"
222+
)?;
223+
224+
continue;
225+
}
226+
193227
writeln!(
194228
file,
195229
"pub fn visit_{}_node<V: Visit + ?Sized>(visitor: &mut V, node: &{}Node) {{",
196230
method_name, node_variant_name
197231
)?;
232+
233+
if let Some(fields) = &node.fields {
234+
for field in fields {
235+
let field_method_name = if field.name == "type" {
236+
"type_"
237+
} else {
238+
field.name.as_str()
239+
};
240+
241+
match field.c_type.as_str() {
242+
"rbs_node" => {
243+
if field.optional {
244+
writeln!(
245+
file,
246+
" if let Some(item) = node.{field_method_name}() {{"
247+
)?;
248+
writeln!(file, " visitor.visit(&item);")?;
249+
writeln!(file, " }}")?;
250+
} else {
251+
writeln!(file, " visitor.visit(&node.{field_method_name}());")?;
252+
}
253+
}
254+
255+
"rbs_node_list" => {
256+
if field.optional {
257+
writeln!(
258+
file,
259+
" if let Some(list) = node.{field_method_name}() {{"
260+
)?;
261+
writeln!(file, " for item in list.iter() {{")?;
262+
writeln!(file, " visitor.visit(&item);")?;
263+
writeln!(file, " }}")?;
264+
writeln!(file, " }}")?;
265+
} else {
266+
writeln!(file, " for item in node.{field_method_name}().iter() {{")?;
267+
writeln!(file, " visitor.visit(&item);")?;
268+
writeln!(file, " }}")?;
269+
}
270+
}
271+
272+
"rbs_hash" => {
273+
if field.optional {
274+
writeln!(
275+
file,
276+
" if let Some(hash) = node.{field_method_name}() {{"
277+
)?;
278+
writeln!(file, " for (key, value) in hash.iter() {{")?;
279+
writeln!(file, " visitor.visit(&key);")?;
280+
writeln!(file, " visitor.visit(&value);")?;
281+
writeln!(file, " }}")?;
282+
writeln!(file, " }}")?;
283+
} else {
284+
writeln!(
285+
file,
286+
" for (key, value) in node.{field_method_name}().iter() {{"
287+
)?;
288+
writeln!(file, " visitor.visit(&key);")?;
289+
writeln!(file, " visitor.visit(&value);")?;
290+
writeln!(file, " }}")?;
291+
}
292+
}
293+
294+
_ => {
295+
if let Some(visit_method_name) = visitor_method_names.get(&field.c_type) {
296+
if field.optional {
297+
writeln!(
298+
file,
299+
" if let Some(item) = node.{field_method_name}() {{"
300+
)?;
301+
writeln!(
302+
file,
303+
" visitor.visit_{visit_method_name}_node(&item);"
304+
)?;
305+
writeln!(file, " }}")?;
306+
} else {
307+
writeln!(
308+
file,
309+
" visitor.visit_{visit_method_name}_node(&node.{field_method_name}());"
310+
)?;
311+
}
312+
}
313+
}
314+
}
315+
}
316+
}
198317
writeln!(file, "}}")?;
199318
writeln!(file)?;
200319
}
@@ -226,6 +345,12 @@ fn generate(config: &Config) -> Result<(), Box<dyn Error>> {
226345
writeln!(file, "}}\n")?;
227346

228347
writeln!(file, "impl {} {{", node.rust_name)?;
348+
writeln!(file, " /// Converts this node to a generic node.")?;
349+
writeln!(file, " #[must_use]")?;
350+
writeln!(file, " pub fn as_node(self) -> Node {{")?;
351+
writeln!(file, " Node::{}(self)", node.variant_name())?;
352+
writeln!(file, " }}")?;
353+
229354
if let Some(fields) = &node.fields {
230355
for field in fields {
231356
match field.c_type.as_str() {

rust/ruby-rbs/src/lib.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,4 +326,113 @@ mod tests {
326326
panic!("Expected TypeAlias with RecordType");
327327
}
328328
}
329+
330+
#[test]
331+
fn visitor_test() {
332+
struct Visitor {
333+
visited: Vec<String>,
334+
}
335+
336+
impl Visit for Visitor {
337+
fn visit_bool_type_node(&mut self, node: &BoolTypeNode) {
338+
self.visited.push("type:bool".to_string());
339+
340+
crate::visit_bool_type_node(self, node);
341+
}
342+
343+
fn visit_class_node(&mut self, node: &ClassNode) {
344+
self.visited.push(format!(
345+
"class:{}",
346+
String::from_utf8(node.name().name().name().to_vec()).unwrap()
347+
));
348+
349+
crate::visit_class_node(self, node);
350+
}
351+
352+
fn visit_class_instance_type_node(&mut self, node: &ClassInstanceTypeNode) {
353+
self.visited.push(format!(
354+
"type:{}",
355+
String::from_utf8(node.name().name().name().to_vec()).unwrap()
356+
));
357+
358+
crate::visit_class_instance_type_node(self, node);
359+
}
360+
361+
fn visit_class_super_node(&mut self, node: &ClassSuperNode) {
362+
self.visited.push(format!(
363+
"super:{}",
364+
String::from_utf8(node.name().name().name().to_vec()).unwrap()
365+
));
366+
367+
crate::visit_class_super_node(self, node);
368+
}
369+
370+
fn visit_function_type_node(&mut self, node: &FunctionTypeNode) {
371+
let count = node.required_positionals().iter().count();
372+
self.visited
373+
.push(format!("function:required_positionals:{count}"));
374+
375+
crate::visit_function_type_node(self, node);
376+
}
377+
378+
fn visit_method_definition_node(&mut self, node: &MethodDefinitionNode) {
379+
self.visited.push(format!(
380+
"method:{}",
381+
String::from_utf8(node.name().name().to_vec()).unwrap()
382+
));
383+
384+
crate::visit_method_definition_node(self, node);
385+
}
386+
387+
fn visit_record_type_node(&mut self, node: &RecordTypeNode) {
388+
self.visited.push("record".to_string());
389+
390+
crate::visit_record_type_node(self, node);
391+
}
392+
393+
fn visit_symbol_node(&mut self, node: &SymbolNode) {
394+
self.visited.push(format!(
395+
"symbol:{}",
396+
String::from_utf8(node.name().to_vec()).unwrap()
397+
));
398+
399+
crate::visit_symbol_node(self, node);
400+
}
401+
}
402+
403+
let rbs_code = r#"
404+
class Foo < Bar
405+
def process: ({ name: String, age: Integer }, bool) -> void
406+
end
407+
"#;
408+
409+
let signature = parse(rbs_code.as_bytes()).unwrap();
410+
411+
let mut visitor = Visitor {
412+
visited: Vec::new(),
413+
};
414+
415+
visitor.visit(&signature.as_node());
416+
417+
assert_eq!(
418+
vec![
419+
"class:Foo",
420+
"symbol:Foo",
421+
"super:Bar",
422+
"symbol:Bar",
423+
"method:process",
424+
"symbol:process",
425+
"function:required_positionals:2",
426+
"record",
427+
"symbol:name",
428+
"type:String",
429+
"symbol:String",
430+
"symbol:age",
431+
"type:Integer",
432+
"symbol:Integer",
433+
"type:bool",
434+
],
435+
visitor.visited
436+
);
437+
}
329438
}

0 commit comments

Comments
 (0)