Skip to content

Commit d7675f8

Browse files
authored
Generate location() accessor for each node type (#74)
Each node already has location data in its C struct, but it wasn't exposed through the Rust API. This adds a generated `location()` method to every node type, making it easy to get source ranges for any part of the AST. Also removing `parser` from location structs as it is not needed.
1 parent 8037e12 commit d7675f8

2 files changed

Lines changed: 101 additions & 14 deletions

File tree

rust/ruby-rbs/build.rs

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,16 @@ fn generate(config: &Config) -> Result<(), Box<dyn Error>> {
350350
writeln!(file, " pub fn as_node(self) -> Node {{")?;
351351
writeln!(file, " Node::{}(self)", node.variant_name())?;
352352
writeln!(file, " }}")?;
353+
writeln!(file)?;
354+
writeln!(file, " /// Returns the location of this node.")?;
355+
writeln!(file, " #[must_use]")?;
356+
writeln!(file, " pub fn location(&self) -> RBSLocation {{")?;
357+
writeln!(
358+
file,
359+
" RBSLocation::new(unsafe {{ (*self.pointer).base.location }})"
360+
)?;
361+
writeln!(file, " }}")?;
362+
writeln!(file)?;
353363

354364
if let Some(fields) = &node.fields {
355365
for field in fields {
@@ -379,10 +389,64 @@ fn generate(config: &Config) -> Result<(), Box<dyn Error>> {
379389
write_node_field_accessor(&mut file, field, "RBSHash")?;
380390
}
381391
"rbs_location" => {
382-
write_node_field_accessor(&mut file, field, "RBSLocation")?;
392+
if field.optional {
393+
writeln!(
394+
file,
395+
" pub fn {}(&self) -> Option<RBSLocation> {{",
396+
field.name
397+
)?;
398+
writeln!(
399+
file,
400+
" let ptr = unsafe {{ (*self.pointer).{} }};",
401+
field.c_name()
402+
)?;
403+
writeln!(file, " if ptr.is_null() {{")?;
404+
writeln!(file, " None")?;
405+
writeln!(file, " }} else {{")?;
406+
writeln!(file, " Some(RBSLocation {{ pointer: ptr }})")?;
407+
writeln!(file, " }}")?;
408+
writeln!(file, " }}")?;
409+
} else {
410+
writeln!(file, " pub fn {}(&self) -> RBSLocation {{", field.name)?;
411+
writeln!(
412+
file,
413+
" RBSLocation {{ pointer: unsafe {{ (*self.pointer).{} }} }}",
414+
field.c_name()
415+
)?;
416+
writeln!(file, " }}")?;
417+
}
383418
}
384419
"rbs_location_list" => {
385-
write_node_field_accessor(&mut file, field, "RBSLocationList")?;
420+
if field.optional {
421+
writeln!(
422+
file,
423+
" pub fn {}(&self) -> Option<RBSLocationList> {{",
424+
field.name
425+
)?;
426+
writeln!(
427+
file,
428+
" let ptr = unsafe {{ (*self.pointer).{} }};",
429+
field.c_name()
430+
)?;
431+
writeln!(file, " if ptr.is_null() {{")?;
432+
writeln!(file, " None")?;
433+
writeln!(file, " }} else {{")?;
434+
writeln!(file, " Some(RBSLocationList {{ pointer: ptr }})")?;
435+
writeln!(file, " }}")?;
436+
writeln!(file, " }}")?;
437+
} else {
438+
writeln!(
439+
file,
440+
" pub fn {}(&self) -> RBSLocationList {{",
441+
field.name
442+
)?;
443+
writeln!(
444+
file,
445+
" RBSLocationList {{ pointer: unsafe {{ (*self.pointer).{} }} }}",
446+
field.c_name()
447+
)?;
448+
writeln!(file, " }}")?;
449+
}
386450
}
387451
"rbs_namespace" => {
388452
write_node_field_accessor(&mut file, field, "NamespaceNode")?;

rust/ruby-rbs/src/lib.rs

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,27 +134,24 @@ impl Iterator for RBSHashIter {
134134

135135
pub struct RBSLocation {
136136
pointer: *const rbs_location_t,
137-
#[allow(dead_code)]
138-
parser: *mut rbs_parser_t,
139137
}
140138

141139
impl RBSLocation {
142-
pub fn new(pointer: *const rbs_location_t, parser: *mut rbs_parser_t) -> Self {
143-
Self { pointer, parser }
140+
pub fn new(pointer: *const rbs_location_t) -> Self {
141+
Self { pointer }
144142
}
145143

146-
pub fn start_loc(&self) -> i32 {
144+
pub fn start(&self) -> i32 {
147145
unsafe { (*self.pointer).rg.start.byte_pos }
148146
}
149147

150-
pub fn end_loc(&self) -> i32 {
148+
pub fn end(&self) -> i32 {
151149
unsafe { (*self.pointer).rg.end.byte_pos }
152150
}
153151
}
154152

155153
pub struct RBSLocationListIter {
156154
current: *mut rbs_location_list_node_t,
157-
parser: *mut rbs_parser_t,
158155
}
159156

160157
impl Iterator for RBSLocationListIter {
@@ -165,7 +162,7 @@ impl Iterator for RBSLocationListIter {
165162
None
166163
} else {
167164
let pointer_data = unsafe { *self.current };
168-
let loc = RBSLocation::new(pointer_data.loc, self.parser);
165+
let loc = RBSLocation::new(pointer_data.loc);
169166
self.current = pointer_data.next;
170167
Some(loc)
171168
}
@@ -174,20 +171,18 @@ impl Iterator for RBSLocationListIter {
174171

175172
pub struct RBSLocationList {
176173
pointer: *mut rbs_location_list,
177-
parser: *mut rbs_parser_t,
178174
}
179175

180176
impl RBSLocationList {
181-
pub fn new(pointer: *mut rbs_location_list, parser: *mut rbs_parser_t) -> Self {
182-
Self { pointer, parser }
177+
pub fn new(pointer: *mut rbs_location_list) -> Self {
178+
Self { pointer }
183179
}
184180

185181
/// Returns an iterator over the locations.
186182
#[must_use]
187183
pub fn iter(&self) -> RBSLocationListIter {
188184
RBSLocationListIter {
189185
current: unsafe { (*self.pointer).head },
190-
parser: self.parser,
191186
}
192187
}
193188
}
@@ -435,4 +430,32 @@ mod tests {
435430
visitor.visited
436431
);
437432
}
433+
434+
#[test]
435+
fn test_node_location_ranges() {
436+
let rbs_code = r#"type foo = 1"#;
437+
let signature = parse(rbs_code.as_bytes()).unwrap();
438+
439+
let declaration = signature.declarations().iter().next().unwrap();
440+
let Node::TypeAlias(type_alias) = declaration else {
441+
panic!("Expected TypeAlias");
442+
};
443+
444+
// TypeAlias spans the entire declaration
445+
let loc = type_alias.location();
446+
assert_eq!(0, loc.start());
447+
assert_eq!(12, loc.end());
448+
449+
// The literal "1" is at position 11-12
450+
let Node::LiteralType(literal) = type_alias.type_() else {
451+
panic!("Expected LiteralType");
452+
};
453+
let Node::Integer(integer) = literal.literal() else {
454+
panic!("Expected Integer");
455+
};
456+
457+
let int_loc = integer.location();
458+
assert_eq!(11, int_loc.start());
459+
assert_eq!(12, int_loc.end());
460+
}
438461
}

0 commit comments

Comments
 (0)