Skip to content

Commit dcb165c

Browse files
committed
Generate child node traversal in visitor functions
1 parent daffb8e commit dcb165c

2 files changed

Lines changed: 227 additions & 1 deletion

File tree

rust/ruby-rbs/build.rs

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,127 @@ fn write_visit_trait(file: &mut File, config: &Config) -> Result<(), Box<dyn std
185185
writeln!(file, "}}")?;
186186
writeln!(file)?;
187187

188+
// Map C field types (e.g. `rbs_type_name`) to the corresponding
189+
// visitor method name (e.g. `type_name` -> `visit_type_name_node`).
190+
let visitor_method_names: std::collections::HashMap<String, String> = config
191+
.nodes
192+
.iter()
193+
.map(|node| {
194+
let c_type = convert_name(&node.name, CIdentifier::Type);
195+
let c_type = c_type.strip_suffix("_t").unwrap_or(&c_type).to_string();
196+
let method = convert_name(node.variant_name(), CIdentifier::Method);
197+
(c_type, method)
198+
})
199+
.collect();
200+
201+
let is_visitable = |c_type: &str| -> bool {
202+
matches!(c_type, "rbs_node" | "rbs_node_list" | "rbs_hash")
203+
|| visitor_method_names.contains_key(c_type)
204+
};
205+
188206
for node in &config.nodes {
189207
let node_variant_name = node.variant_name();
190208
let method_name = convert_name(node_variant_name, CIdentifier::Method);
191209

192-
writeln!(file, "#[allow(unused_variables)]")?; // TODO: Remove this once all nodes that need visitor are implemented
210+
let has_visitable_fields = node
211+
.fields
212+
.iter()
213+
.flatten()
214+
.any(|field| is_visitable(&field.c_type));
215+
216+
if !has_visitable_fields {
217+
writeln!(file, "#[allow(unused_variables)]")?;
218+
}
193219
writeln!(
194220
file,
195221
"pub fn visit_{}_node<V: Visit + ?Sized>(visitor: &mut V, node: &{}Node) {{",
196222
method_name, node_variant_name
197223
)?;
224+
225+
if let Some(fields) = &node.fields {
226+
for field in fields {
227+
let field_method_name = if field.name == "type" {
228+
"type_"
229+
} else {
230+
field.name.as_str()
231+
};
232+
233+
match field.c_type.as_str() {
234+
"rbs_node" => {
235+
if field.optional {
236+
writeln!(
237+
file,
238+
" if let Some(item) = node.{field_method_name}() {{"
239+
)?;
240+
writeln!(file, " visitor.visit(&item);")?;
241+
writeln!(file, " }}")?;
242+
} else {
243+
writeln!(file, " visitor.visit(&node.{field_method_name}());")?;
244+
}
245+
}
246+
247+
"rbs_node_list" => {
248+
if field.optional {
249+
writeln!(
250+
file,
251+
" if let Some(list) = node.{field_method_name}() {{"
252+
)?;
253+
writeln!(file, " for item in list.iter() {{")?;
254+
writeln!(file, " visitor.visit(&item);")?;
255+
writeln!(file, " }}")?;
256+
writeln!(file, " }}")?;
257+
} else {
258+
writeln!(file, " for item in node.{field_method_name}().iter() {{")?;
259+
writeln!(file, " visitor.visit(&item);")?;
260+
writeln!(file, " }}")?;
261+
}
262+
}
263+
264+
"rbs_hash" => {
265+
if field.optional {
266+
writeln!(
267+
file,
268+
" if let Some(hash) = node.{field_method_name}() {{"
269+
)?;
270+
writeln!(file, " for (key, value) in hash.iter() {{")?;
271+
writeln!(file, " visitor.visit(&key);")?;
272+
writeln!(file, " visitor.visit(&value);")?;
273+
writeln!(file, " }}")?;
274+
writeln!(file, " }}")?;
275+
} else {
276+
writeln!(
277+
file,
278+
" for (key, value) in node.{field_method_name}().iter() {{"
279+
)?;
280+
writeln!(file, " visitor.visit(&key);")?;
281+
writeln!(file, " visitor.visit(&value);")?;
282+
writeln!(file, " }}")?;
283+
}
284+
}
285+
286+
_ => {
287+
if let Some(visit_method_name) = visitor_method_names.get(&field.c_type) {
288+
if field.optional {
289+
writeln!(
290+
file,
291+
" if let Some(item) = node.{field_method_name}() {{"
292+
)?;
293+
writeln!(
294+
file,
295+
" visitor.visit_{visit_method_name}_node(&item);"
296+
)?;
297+
writeln!(file, " }}")?;
298+
} else {
299+
writeln!(
300+
file,
301+
" visitor.visit_{visit_method_name}_node(&node.{field_method_name}());"
302+
)?;
303+
}
304+
}
305+
}
306+
}
307+
}
308+
}
198309
writeln!(file, "}}")?;
199310
writeln!(file)?;
200311
}
@@ -226,6 +337,12 @@ fn generate(config: &Config) -> Result<(), Box<dyn Error>> {
226337
writeln!(file, "}}\n")?;
227338

228339
writeln!(file, "impl {} {{", node.rust_name)?;
340+
writeln!(file, " /// Converts this node to a generic node.")?;
341+
writeln!(file, " #[must_use]")?;
342+
writeln!(file, " pub fn as_node(self) -> Node {{")?;
343+
writeln!(file, " Node::{}(self)", node.variant_name())?;
344+
writeln!(file, " }}")?;
345+
229346
if let Some(fields) = &node.fields {
230347
for field in fields {
231348
match field.c_type.as_str() {

rust/ruby-rbs/src/lib.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,4 +326,113 @@ mod tests {
326326
panic!("Expected TypeAlias with RecordType");
327327
}
328328
}
329+
330+
#[test]
331+
fn visitor_test() {
332+
struct Visitor {
333+
visited: Vec<String>,
334+
}
335+
336+
impl Visit for Visitor {
337+
fn visit_bool_type_node(&mut self, node: &BoolTypeNode) {
338+
self.visited.push("type:bool".to_string());
339+
340+
crate::visit_bool_type_node(self, node);
341+
}
342+
343+
fn visit_class_node(&mut self, node: &ClassNode) {
344+
self.visited.push(format!(
345+
"class:{}",
346+
String::from_utf8(node.name().name().name().to_vec()).unwrap()
347+
));
348+
349+
crate::visit_class_node(self, node);
350+
}
351+
352+
fn visit_class_instance_type_node(&mut self, node: &ClassInstanceTypeNode) {
353+
self.visited.push(format!(
354+
"type:{}",
355+
String::from_utf8(node.name().name().name().to_vec()).unwrap()
356+
));
357+
358+
crate::visit_class_instance_type_node(self, node);
359+
}
360+
361+
fn visit_class_super_node(&mut self, node: &ClassSuperNode) {
362+
self.visited.push(format!(
363+
"super:{}",
364+
String::from_utf8(node.name().name().name().to_vec()).unwrap()
365+
));
366+
367+
crate::visit_class_super_node(self, node);
368+
}
369+
370+
fn visit_function_type_node(&mut self, node: &FunctionTypeNode) {
371+
let count = node.required_positionals().iter().count();
372+
self.visited
373+
.push(format!("function:required_positionals:{count}"));
374+
375+
crate::visit_function_type_node(self, node);
376+
}
377+
378+
fn visit_method_definition_node(&mut self, node: &MethodDefinitionNode) {
379+
self.visited.push(format!(
380+
"method:{}",
381+
String::from_utf8(node.name().name().to_vec()).unwrap()
382+
));
383+
384+
crate::visit_method_definition_node(self, node);
385+
}
386+
387+
fn visit_record_type_node(&mut self, node: &RecordTypeNode) {
388+
self.visited.push("record".to_string());
389+
390+
crate::visit_record_type_node(self, node);
391+
}
392+
393+
fn visit_symbol_node(&mut self, node: &SymbolNode) {
394+
self.visited.push(format!(
395+
"symbol:{}",
396+
String::from_utf8(node.name().to_vec()).unwrap()
397+
));
398+
399+
crate::visit_symbol_node(self, node);
400+
}
401+
}
402+
403+
let rbs_code = r#"
404+
class Foo < Bar
405+
def process: ({ name: String, age: Integer }, bool) -> void
406+
end
407+
"#;
408+
409+
let signature = parse(rbs_code.as_bytes()).unwrap();
410+
411+
let mut visitor = Visitor {
412+
visited: Vec::new(),
413+
};
414+
415+
visitor.visit(&signature.as_node());
416+
417+
assert_eq!(
418+
vec![
419+
"class:Foo",
420+
"symbol:Foo",
421+
"super:Bar",
422+
"symbol:Bar",
423+
"method:process",
424+
"symbol:process",
425+
"function:required_positionals:2",
426+
"record",
427+
"symbol:name",
428+
"type:String",
429+
"symbol:String",
430+
"symbol:age",
431+
"type:Integer",
432+
"symbol:Integer",
433+
"type:bool",
434+
],
435+
visitor.visited
436+
);
437+
}
329438
}

0 commit comments

Comments
 (0)