Skip to content

Commit 93ecfd6

Browse files
authored
Merge pull request #2917 from ruby/rust-api
Better Rust API
2 parents a1fab76 + 4d1c454 commit 93ecfd6

File tree

3 files changed

+48
-43
lines changed

3 files changed

+48
-43
lines changed

rust/ruby-rbs/examples/locations.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use ruby_rbs::node::{Node, parse};
22

33
fn main() {
44
let rbs_code = r#"class Foo[T] < Bar end"#;
5-
let signature = parse(rbs_code.as_bytes()).unwrap();
5+
let signature = parse(rbs_code).unwrap();
66

77
let declaration = signature.declarations().iter().next().unwrap();
88
if let Node::Class(class) = declaration {

rust/ruby-rbs/src/node/mod.rs

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ use std::ptr::NonNull;
99
/// ```rust
1010
/// use ruby_rbs::node::parse;
1111
/// let rbs_code = r#"type foo = "hello""#;
12-
/// let signature = parse(rbs_code.as_bytes());
12+
/// let signature = parse(rbs_code);
1313
/// assert!(signature.is_ok(), "Failed to parse RBS signature");
1414
/// ```
15-
pub fn parse(rbs_code: &[u8]) -> Result<SignatureNode<'_>, String> {
15+
pub fn parse(rbs_code: &str) -> Result<SignatureNode<'_>, String> {
1616
unsafe {
1717
let start_ptr = rbs_code.as_ptr().cast::<std::os::raw::c_char>();
1818
let end_ptr = start_ptr.add(rbs_code.len());
@@ -253,17 +253,29 @@ impl<'a> RBSString<'a> {
253253
}
254254

255255
#[must_use]
256+
#[allow(clippy::unnecessary_cast)]
256257
pub fn as_bytes(&self) -> &[u8] {
257258
unsafe {
258259
let s = *self.pointer;
259260
std::slice::from_raw_parts(s.start as *const u8, s.end.offset_from(s.start) as usize)
260261
}
261262
}
263+
264+
#[must_use]
265+
pub fn as_str(&self) -> &str {
266+
unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
267+
}
268+
}
269+
270+
impl std::fmt::Display for RBSString<'_> {
271+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272+
f.write_str(self.as_str())
273+
}
262274
}
263275

264276
impl SymbolNode<'_> {
265277
#[must_use]
266-
pub fn name(&self) -> &[u8] {
278+
pub fn as_bytes(&self) -> &[u8] {
267279
unsafe {
268280
let constant_ptr = rbs_constant_pool_id_to_constant(
269281
&(*self.parser.as_ptr()).constant_pool,
@@ -277,6 +289,17 @@ impl SymbolNode<'_> {
277289
std::slice::from_raw_parts(constant.start, constant.length)
278290
}
279291
}
292+
293+
#[must_use]
294+
pub fn as_str(&self) -> &str {
295+
unsafe { std::str::from_utf8_unchecked(self.as_bytes()) }
296+
}
297+
}
298+
299+
impl std::fmt::Display for SymbolNode<'_> {
300+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301+
f.write_str(self.as_str())
302+
}
280303
}
281304

282305
#[cfg(test)]
@@ -286,37 +309,34 @@ mod tests {
286309
#[test]
287310
fn test_parse_error_contains_actual_message() {
288311
let rbs_code = "class { end";
289-
let result = parse(rbs_code.as_bytes());
312+
let result = parse(rbs_code);
290313
let error_message = result.unwrap_err();
291314
assert_eq!(error_message, "expected one of class/module/constant name");
292315
}
293316

294317
#[test]
295318
fn test_parse() {
296319
let rbs_code = r#"type foo = "hello""#;
297-
let signature = parse(rbs_code.as_bytes());
320+
let signature = parse(rbs_code);
298321
assert!(signature.is_ok(), "Failed to parse RBS signature");
299322

300323
let rbs_code2 = r#"class Foo end"#;
301-
let signature2 = parse(rbs_code2.as_bytes());
324+
let signature2 = parse(rbs_code2);
302325
assert!(signature2.is_ok(), "Failed to parse RBS signature");
303326
}
304327

305328
#[test]
306329
fn test_parse_integer() {
307330
let rbs_code = r#"type foo = 1"#;
308-
let signature = parse(rbs_code.as_bytes());
331+
let signature = parse(rbs_code);
309332
assert!(signature.is_ok(), "Failed to parse RBS signature");
310333

311334
let signature_node = signature.unwrap();
312335
if let Node::TypeAlias(node) = signature_node.declarations().iter().next().unwrap()
313336
&& let Node::LiteralType(literal) = node.type_()
314337
&& let Node::Integer(integer) = literal.literal()
315338
{
316-
assert_eq!(
317-
"1".to_string(),
318-
String::from_utf8(integer.string_representation().as_bytes().to_vec()).unwrap()
319-
);
339+
assert_eq!(integer.string_representation().as_str(), "1");
320340
} else {
321341
panic!("No literal type node found");
322342
}
@@ -326,7 +346,7 @@ mod tests {
326346
fn test_rbs_hash_via_record_type() {
327347
// RecordType stores its fields in an RBSHash via all_fields()
328348
let rbs_code = r#"type foo = { name: String, age: Integer }"#;
329-
let signature = parse(rbs_code.as_bytes());
349+
let signature = parse(rbs_code);
330350
assert!(signature.is_ok(), "Failed to parse RBS signature");
331351

332352
let signature_node = signature.unwrap();
@@ -350,10 +370,10 @@ mod tests {
350370
panic!("Expected ClassInstanceType");
351371
};
352372

353-
let key_name = String::from_utf8(sym.name().to_vec()).unwrap();
373+
let key_name = sym.to_string();
354374
let type_name_node = class_type.name();
355375
let type_name_sym = type_name_node.name();
356-
let type_name = String::from_utf8(type_name_sym.name().to_vec()).unwrap();
376+
let type_name = type_name_sym.to_string();
357377
field_types.push((key_name, type_name));
358378
}
359379

@@ -384,28 +404,19 @@ mod tests {
384404
}
385405

386406
fn visit_class_node(&mut self, node: &ClassNode) {
387-
self.visited.push(format!(
388-
"class:{}",
389-
String::from_utf8(node.name().name().name().to_vec()).unwrap()
390-
));
407+
self.visited.push(format!("class:{}", node.name().name()));
391408

392409
crate::node::visit_class_node(self, node);
393410
}
394411

395412
fn visit_class_instance_type_node(&mut self, node: &ClassInstanceTypeNode) {
396-
self.visited.push(format!(
397-
"type:{}",
398-
String::from_utf8(node.name().name().name().to_vec()).unwrap()
399-
));
413+
self.visited.push(format!("type:{}", node.name().name()));
400414

401415
crate::node::visit_class_instance_type_node(self, node);
402416
}
403417

404418
fn visit_class_super_node(&mut self, node: &ClassSuperNode) {
405-
self.visited.push(format!(
406-
"super:{}",
407-
String::from_utf8(node.name().name().name().to_vec()).unwrap()
408-
));
419+
self.visited.push(format!("super:{}", node.name().name()));
409420

410421
crate::node::visit_class_super_node(self, node);
411422
}
@@ -419,10 +430,7 @@ mod tests {
419430
}
420431

421432
fn visit_method_definition_node(&mut self, node: &MethodDefinitionNode) {
422-
self.visited.push(format!(
423-
"method:{}",
424-
String::from_utf8(node.name().name().to_vec()).unwrap()
425-
));
433+
self.visited.push(format!("method:{}", node.name()));
426434

427435
crate::node::visit_method_definition_node(self, node);
428436
}
@@ -434,10 +442,7 @@ mod tests {
434442
}
435443

436444
fn visit_symbol_node(&mut self, node: &SymbolNode) {
437-
self.visited.push(format!(
438-
"symbol:{}",
439-
String::from_utf8(node.name().to_vec()).unwrap()
440-
));
445+
self.visited.push(format!("symbol:{node}"));
441446

442447
crate::node::visit_symbol_node(self, node);
443448
}
@@ -449,7 +454,7 @@ mod tests {
449454
end
450455
"#;
451456

452-
let signature = parse(rbs_code.as_bytes()).unwrap();
457+
let signature = parse(rbs_code).unwrap();
453458

454459
let mut visitor = Visitor {
455460
visited: Vec::new(),
@@ -482,7 +487,7 @@ mod tests {
482487
#[test]
483488
fn test_node_location_ranges() {
484489
let rbs_code = r#"type foo = 1"#;
485-
let signature = parse(rbs_code.as_bytes()).unwrap();
490+
let signature = parse(rbs_code).unwrap();
486491

487492
let declaration = signature.declarations().iter().next().unwrap();
488493
let Node::TypeAlias(type_alias) = declaration else {
@@ -510,7 +515,7 @@ mod tests {
510515
#[test]
511516
fn test_sub_locations() {
512517
let rbs_code = r#"class Foo < Bar end"#;
513-
let signature = parse(rbs_code.as_bytes()).unwrap();
518+
let signature = parse(rbs_code).unwrap();
514519

515520
let declaration = signature.declarations().iter().next().unwrap();
516521
let Node::Class(class) = declaration else {
@@ -545,7 +550,7 @@ mod tests {
545550
#[test]
546551
fn test_type_alias_sub_locations() {
547552
let rbs_code = r#"type foo = String"#;
548-
let signature = parse(rbs_code.as_bytes()).unwrap();
553+
let signature = parse(rbs_code).unwrap();
549554

550555
let declaration = signature.declarations().iter().next().unwrap();
551556
let Node::TypeAlias(type_alias) = declaration else {
@@ -573,7 +578,7 @@ mod tests {
573578
#[test]
574579
fn test_module_sub_locations() {
575580
let rbs_code = r#"module Foo[T] : Bar end"#;
576-
let signature = parse(rbs_code.as_bytes()).unwrap();
581+
let signature = parse(rbs_code).unwrap();
577582

578583
let declaration = signature.declarations().iter().next().unwrap();
579584
let Node::Module(module) = declaration else {
@@ -626,7 +631,7 @@ mod tests {
626631
class Bar[out T, in U, V]
627632
end
628633
"#;
629-
let signature = parse(rbs_code.as_bytes()).unwrap();
634+
let signature = parse(rbs_code).unwrap();
630635

631636
let declarations: Vec<_> = signature.declarations().iter().collect();
632637

@@ -706,7 +711,7 @@ mod tests {
706711
attr_writer email(@email): String
707712
end
708713
"#;
709-
let signature = parse(rbs_code.as_bytes()).unwrap();
714+
let signature = parse(rbs_code).unwrap();
710715

711716
let Node::Class(class) = signature.declarations().iter().next().unwrap() else {
712717
panic!("Expected Class");

rust/ruby-rbs/tests/sanity.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ fn all_included_rbs_can_be_parsed() {
3333
for file in &files {
3434
let content = std::fs::read_to_string(file).unwrap();
3535

36-
if let Err(e) = parse(content.as_bytes()) {
36+
if let Err(e) = parse(&content) {
3737
failures.push(format!("{}: {}", file.display(), e));
3838
}
3939
}

0 commit comments

Comments
 (0)