Skip to content

Commit 5ceabb3

Browse files
committed
Generate location() accessor for each node type
Each node already has location data in its C struct, but it wasn't exposed through the Rust API. This adds a generated `location()` method to every node type, making it easy to get source ranges for any part of the AST. Also reorders RBSLocation and RBSLocationList struct fields to consistently put parser first, matching the pattern used elsewhere.
1 parent 8037e12 commit 5ceabb3

2 files changed

Lines changed: 47 additions & 9 deletions

File tree

rust/ruby-rbs/build.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,16 @@ fn generate(config: &Config) -> Result<(), Box<dyn Error>> {
350350
writeln!(file, " pub fn as_node(self) -> Node {{")?;
351351
writeln!(file, " Node::{}(self)", node.variant_name())?;
352352
writeln!(file, " }}")?;
353+
writeln!(file)?;
354+
writeln!(file, " /// Returns the location of this node.")?;
355+
writeln!(file, " #[must_use]")?;
356+
writeln!(file, " pub fn location(&self) -> RBSLocation {{")?;
357+
writeln!(
358+
file,
359+
" RBSLocation::new(self.parser, unsafe {{ (*self.pointer).base.location }})"
360+
)?;
361+
writeln!(file, " }}")?;
362+
writeln!(file)?;
353363

354364
if let Some(fields) = &node.fields {
355365
for field in fields {

rust/ruby-rbs/src/lib.rs

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,14 @@ impl Iterator for RBSHashIter {
133133
}
134134

135135
pub struct RBSLocation {
136-
pointer: *const rbs_location_t,
137136
#[allow(dead_code)]
138137
parser: *mut rbs_parser_t,
138+
pointer: *const rbs_location_t,
139139
}
140140

141141
impl RBSLocation {
142-
pub fn new(pointer: *const rbs_location_t, parser: *mut rbs_parser_t) -> Self {
143-
Self { pointer, parser }
142+
pub fn new(parser: *mut rbs_parser_t, pointer: *const rbs_location_t) -> Self {
143+
Self { parser, pointer }
144144
}
145145

146146
pub fn start_loc(&self) -> i32 {
@@ -153,8 +153,8 @@ impl RBSLocation {
153153
}
154154

155155
pub struct RBSLocationListIter {
156-
current: *mut rbs_location_list_node_t,
157156
parser: *mut rbs_parser_t,
157+
current: *mut rbs_location_list_node_t,
158158
}
159159

160160
impl Iterator for RBSLocationListIter {
@@ -165,29 +165,29 @@ impl Iterator for RBSLocationListIter {
165165
None
166166
} else {
167167
let pointer_data = unsafe { *self.current };
168-
let loc = RBSLocation::new(pointer_data.loc, self.parser);
168+
let loc = RBSLocation::new(self.parser, pointer_data.loc);
169169
self.current = pointer_data.next;
170170
Some(loc)
171171
}
172172
}
173173
}
174174

175175
pub struct RBSLocationList {
176-
pointer: *mut rbs_location_list,
177176
parser: *mut rbs_parser_t,
177+
pointer: *mut rbs_location_list,
178178
}
179179

180180
impl RBSLocationList {
181-
pub fn new(pointer: *mut rbs_location_list, parser: *mut rbs_parser_t) -> Self {
182-
Self { pointer, parser }
181+
pub fn new(parser: *mut rbs_parser_t, pointer: *mut rbs_location_list) -> Self {
182+
Self { parser, pointer }
183183
}
184184

185185
/// Returns an iterator over the locations.
186186
#[must_use]
187187
pub fn iter(&self) -> RBSLocationListIter {
188188
RBSLocationListIter {
189-
current: unsafe { (*self.pointer).head },
190189
parser: self.parser,
190+
current: unsafe { (*self.pointer).head },
191191
}
192192
}
193193
}
@@ -435,4 +435,32 @@ mod tests {
435435
visitor.visited
436436
);
437437
}
438+
439+
#[test]
440+
fn test_node_location_ranges() {
441+
let rbs_code = r#"type foo = 1"#;
442+
let signature = parse(rbs_code.as_bytes()).unwrap();
443+
444+
let declaration = signature.declarations().iter().next().unwrap();
445+
let Node::TypeAlias(type_alias) = declaration else {
446+
panic!("Expected TypeAlias");
447+
};
448+
449+
// TypeAlias spans the entire declaration
450+
let loc = type_alias.location();
451+
assert_eq!(0, loc.start_loc());
452+
assert_eq!(12, loc.end_loc());
453+
454+
// The literal "1" is at position 11-12
455+
let Node::LiteralType(literal) = type_alias.type_() else {
456+
panic!("Expected LiteralType");
457+
};
458+
let Node::Integer(integer) = literal.literal() else {
459+
panic!("Expected Integer");
460+
};
461+
462+
let int_loc = integer.location();
463+
assert_eq!(11, int_loc.start_loc());
464+
assert_eq!(12, int_loc.end_loc());
465+
}
438466
}

0 commit comments

Comments
 (0)