Skip to content

Commit 1849cbe

Browse files
committed
Merge pull request #2457 from Shopify/at-parse-type-params
Expose a method to parse type parameters
1 parent 1bdaa8b commit 1849cbe

6 files changed

Lines changed: 151 additions & 0 deletions

File tree

ext/rbs_extension/main.c

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,55 @@ static VALUE rbsparser_parse_signature(VALUE self, VALUE buffer, VALUE start_pos
267267
return result;
268268
}
269269

270+
struct parse_type_params_arg {
271+
VALUE buffer;
272+
rb_encoding *encoding;
273+
rbs_parser_t *parser;
274+
VALUE module_type_params;
275+
};
276+
277+
static VALUE parse_type_params_try(VALUE a) {
278+
struct parse_type_params_arg *arg = (struct parse_type_params_arg *) a;
279+
rbs_parser_t *parser = arg->parser;
280+
281+
if (parser->next_token.type == pEOF) {
282+
return Qnil;
283+
}
284+
285+
rbs_node_list_t *params = NULL;
286+
rbs_parse_type_params(parser, arg->module_type_params, &params);
287+
288+
raise_error_if_any(parser, arg->buffer);
289+
290+
rbs_translation_context_t ctx = rbs_translation_context_create(
291+
&parser->constant_pool,
292+
arg->buffer,
293+
arg->encoding
294+
);
295+
296+
return rbs_node_list_to_ruby_array(ctx, params);
297+
}
298+
299+
static VALUE rbsparser_parse_type_params(VALUE self, VALUE buffer, VALUE start_pos, VALUE end_pos, VALUE module_type_params) {
300+
VALUE string = rb_funcall(buffer, rb_intern("content"), 0);
301+
StringValue(string);
302+
rb_encoding *encoding = rb_enc_get(string);
303+
304+
rbs_parser_t *parser = alloc_parser_from_buffer(buffer, FIX2INT(start_pos), FIX2INT(end_pos));
305+
struct parse_type_params_arg arg = {
306+
.buffer = buffer,
307+
.encoding = encoding,
308+
.parser = parser,
309+
.module_type_params = module_type_params
310+
};
311+
312+
VALUE result = rb_ensure(parse_type_params_try, (VALUE) &arg, ensure_free_parser, (VALUE) parser);
313+
314+
RB_GC_GUARD(string);
315+
316+
return result;
317+
}
318+
270319
static VALUE rbsparser_lex(VALUE self, VALUE buffer, VALUE end_pos) {
271320
VALUE string = rb_funcall(buffer, rb_intern("content"), 0);
272321
StringValue(string);
@@ -304,6 +353,7 @@ void rbs__init_parser(void) {
304353
rb_define_singleton_method(RBS_Parser, "_parse_type", rbsparser_parse_type, 5);
305354
rb_define_singleton_method(RBS_Parser, "_parse_method_type", rbsparser_parse_method_type, 5);
306355
rb_define_singleton_method(RBS_Parser, "_parse_signature", rbsparser_parse_signature, 3);
356+
rb_define_singleton_method(RBS_Parser, "_parse_type_params", rbsparser_parse_type_params, 4);
307357
rb_define_singleton_method(RBS_Parser, "_lex", rbsparser_lex, 2);
308358
}
309359

include/rbs/parser.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,4 +130,6 @@ bool rbs_parse_type(rbs_parser_t *parser, rbs_node_t **type);
130130
bool rbs_parse_method_type(rbs_parser_t *parser, rbs_method_type_t **method_type);
131131
bool rbs_parse_signature(rbs_parser_t *parser, rbs_signature_t **signature);
132132

133+
bool rbs_parse_type_params(rbs_parser_t *parser, bool module_type_params, rbs_node_list_t **params);
134+
133135
#endif

lib/rbs/parser_aux.rb

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ def self.parse_signature(source)
3535
[buf, dirs, decls]
3636
end
3737

38+
def self.parse_type_params(source, module_type_params: true)
39+
buf = buffer(source)
40+
_parse_type_params(buf, 0, buf.last_position, module_type_params)
41+
end
42+
3843
def self.magic_comment(buf)
3944
start_pos = 0
4045

sig/parser.rbs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,24 @@ module RBS
6868
#
6969
def self.parse_signature: (Buffer | String) -> [Buffer, Array[AST::Directives::t], Array[AST::Declarations::t]]
7070

71+
# Parse a list of type parameters and return it
72+
#
73+
# ```ruby
74+
# RBS::Parser.parse_type_params("") # => nil
75+
# RBS::Parser.parse_type_params("[U, V]") # => `[:U, :V]`
76+
# RBS::Parser.parse_type_params("[in U, V < Integer]") # => `[:U, :V]`
77+
# ```
78+
#
79+
# When `module_type_params` is `false`, an error is raised if `unchecked`, `in` or `out` are used.
80+
#
81+
# ```ruby
82+
# RBS::Parser.parse_type_params("[unchecked U]", module_type_params: false) # => Raises an error
83+
# RBS::Parser.parse_type_params("[out U]", module_type_params: false) # => Raises an error
84+
# RBS::Parser.parse_type_params("[in U]", module_type_params: false) # => Raises an error
85+
# ```
86+
#
87+
def self.parse_type_params: (Buffer | String, ?module_type_params: bool) -> Array[AST::TypeParam]
88+
7189
# Returns the magic comment from the buffer
7290
#
7391
def self.magic_comment: (Buffer) -> AST::Directives::ResolveTypeNames?
@@ -92,6 +110,8 @@ module RBS
92110

93111
def self._parse_signature: (Buffer, Integer start_pos, Integer end_pos) -> [Array[AST::Directives::t], Array[AST::Declarations::t]]
94112

113+
def self._parse_type_params: (Buffer, Integer start_pos, Integer end_pos, bool module_type_params) -> Array[AST::TypeParam]
114+
95115
def self._lex: (Buffer, Integer end_pos) -> Array[[Symbol, Location[untyped, untyped]]]
96116

97117
class LocatedValue

src/parser.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3228,6 +3228,26 @@ bool rbs_parse_signature(rbs_parser_t *parser, rbs_signature_t **signature) {
32283228
return true;
32293229
}
32303230

3231+
bool rbs_parse_type_params(rbs_parser_t *parser, bool module_type_params, rbs_node_list_t **params) {
3232+
if (parser->next_token.type != pLBRACKET) {
3233+
rbs_parser_set_error(parser, parser->next_token, true, "expected a token `pLBRACKET`");
3234+
return false;
3235+
}
3236+
3237+
rbs_range_t rg = NULL_RANGE;
3238+
rbs_parser_push_typevar_table(parser, true);
3239+
bool res = parse_type_params(parser, &rg, module_type_params, params);
3240+
rbs_parser_push_typevar_table(parser, false);
3241+
3242+
rbs_parser_advance(parser);
3243+
if (parser->current_token.type != pEOF) {
3244+
rbs_parser_set_error(parser, parser->current_token, true, "expected a token `%s`", rbs_token_type_str(pEOF));
3245+
return false;
3246+
}
3247+
3248+
return res;
3249+
}
3250+
32313251
id_table *alloc_empty_table(rbs_allocator_t *allocator) {
32323252
id_table *table = rbs_allocator_alloc(allocator, id_table);
32333253

test/rbs/parser_test.rb

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,60 @@ def test_proc__untyped_function
820820
end
821821
end
822822

823+
def test_parse_type_params
824+
RBS::Parser.parse_type_params(buffer("[T]")).tap do |params|
825+
assert_equal 1, params.size
826+
assert_equal :T, params[0].name
827+
assert_nil params[0].upper_bound
828+
end
829+
830+
RBS::Parser.parse_type_params(buffer("[T < Integer, U = String]")).tap do |params|
831+
assert_equal 2, params.size
832+
assert_equal :T, params[0].name
833+
assert_equal "Integer", params[0].upper_bound.to_s
834+
assert_equal :U, params[1].name
835+
assert_equal "String", params[1].default_type.to_s
836+
end
837+
838+
RBS::Parser.parse_type_params(buffer("[T, in U, out V]")).tap do |params|
839+
assert_equal 3, params.size
840+
assert_equal :T, params[0].name
841+
assert_equal "invariant", params[0].variance.to_s
842+
assert_equal :U, params[1].name
843+
assert_equal "contravariant", params[1].variance.to_s
844+
assert_equal :V, params[2].name
845+
assert_equal "covariant", params[2].variance.to_s
846+
end
847+
848+
RBS::Parser.parse_type_params(buffer("[T, unchecked U, unchecked out V = Integer]")).tap do |params|
849+
assert_equal 3, params.size
850+
assert_equal :T, params[0].name
851+
refute params[0].unchecked?
852+
assert_equal :U, params[1].name
853+
assert params[1].unchecked?
854+
assert_equal :V, params[2].name
855+
assert params[2].unchecked?
856+
assert_equal "covariant", params[2].variance.to_s
857+
assert_equal "Integer", params[2].default_type.to_s
858+
end
859+
860+
assert_raises RBS::ParsingError do
861+
RBS::Parser.parse_type_params(buffer("[]"))
862+
end
863+
864+
assert_raises RBS::ParsingError do
865+
RBS::Parser.parse_type_params(buffer("[T]A"))
866+
end
867+
868+
assert_raises RBS::ParsingError do
869+
RBS::Parser.parse_type_params(buffer("[in T]"), module_type_params: false)
870+
end
871+
872+
assert_raises RBS::ParsingError do
873+
RBS::Parser.parse_type_params(buffer("[unchecked T]"), module_type_params: false)
874+
end
875+
end
876+
823877
def test__lex
824878
content = <<~RBS
825879
# LineComment

0 commit comments

Comments
 (0)