@@ -24,6 +24,7 @@ class DocStyle
2424 attr_reader :required_keywords
2525 attr_reader :optional_keywords
2626 attr_accessor :rest_keywords
27+ attr_accessor :block
2728
2829 def initialize
2930 @return_type_annotation = nil
@@ -34,6 +35,7 @@ def initialize
3435 @required_keywords = { }
3536 @optional_keywords = { }
3637 @rest_keywords = nil
38+ @block = nil
3739 end
3840
3941 def self . build ( param_type_annotations , return_type_annotation , node )
@@ -42,8 +44,9 @@ def self.build(param_type_annotations, return_type_annotation, node)
4244
4345 splat_annotation = nil #: Annotations::SplatParamTypeAnnotation?
4446 double_splat_annotation = nil #: Annotations::DoubleSplatParamTypeAnnotation?
47+ block_annotation = nil #: Annotations::BlockParamTypeAnnotation?
4548 param_annotations = { } #: Hash[Symbol, Annotations::ParamTypeAnnotation]
46- unused = [ ] #: Array[Annotations::ParamTypeAnnotation | Annotations::SplatParamTypeAnnotation | Annotations::DoubleSplatParamTypeAnnotation]
49+ unused = [ ] #: Array[Annotations::ParamTypeAnnotation | Annotations::SplatParamTypeAnnotation | Annotations::DoubleSplatParamTypeAnnotation | Annotations::BlockParamTypeAnnotation ]
4750
4851 param_type_annotations . each do |annot |
4952 case annot
@@ -59,6 +62,12 @@ def self.build(param_type_annotations, return_type_annotation, node)
5962 else
6063 double_splat_annotation = annot
6164 end
65+ when Annotations ::BlockParamTypeAnnotation
66+ if block_annotation
67+ unused << annot
68+ else
69+ block_annotation = annot
70+ end
6271 when Annotations ::ParamTypeAnnotation
6372 name = annot . name_location . source . to_sym
6473 if param_annotations . key? ( name )
@@ -141,11 +150,30 @@ def self.build(param_type_annotations, return_type_annotation, node)
141150 doc . rest_keywords = kw_rest . name || true
142151 end
143152 end
153+
154+ if ( blk = params . block ) && blk . is_a? ( Prism ::BlockParameterNode )
155+ if block_annotation && ( block_annotation . name_location . nil? || blk . name . nil? || block_annotation . name == blk . name )
156+ doc . block = block_annotation
157+ block_annotation = nil
158+ else
159+ doc . block = blk . name || true
160+ end
161+ end
162+ end
163+
164+ if block_annotation
165+ if node . parameters &.block
166+ # Block parameter exists but name didn't match -- treat as unused
167+ else
168+ doc . block = block_annotation
169+ block_annotation = nil
170+ end
144171 end
145172
146173 unused . concat ( param_annotations . values )
147174 unused << splat_annotation if splat_annotation
148175 unused << double_splat_annotation if double_splat_annotation
176+ unused << block_annotation if block_annotation
149177
150178 [ doc , unused ]
151179 end
@@ -160,6 +188,7 @@ def all_param_annotations
160188 required_keywords . each_value { |a | annotations << a }
161189 optional_keywords . each_value { |a | annotations << a }
162190 annotations << rest_keywords
191+ annotations << block
163192
164193 annotations
165194 end
@@ -231,6 +260,13 @@ def map_type_name(&block)
231260 else
232261 rest_keywords
233262 end
263+ new . block =
264+ case self . block
265+ when Annotations ::BlockParamTypeAnnotation
266+ self . block . map_type_name ( &block )
267+ else
268+ self . block
269+ end
234270 end #: self
235271 end
236272
@@ -337,10 +373,26 @@ def method_type
337373 return_type : return_type
338374 )
339375
376+ method_block =
377+ case self . block
378+ when Annotations ::BlockParamTypeAnnotation
379+ Types ::Block . new (
380+ type : self . block . type ,
381+ required : self . block . required?
382+ )
383+ when Symbol , true
384+ Types ::Block . new (
385+ type : Types ::UntypedFunction . new ( return_type : Types ::Bases ::Any . new ( location : nil ) ) ,
386+ required : false
387+ )
388+ else
389+ nil
390+ end
391+
340392 MethodType . new (
341393 type_params : [ ] ,
342394 type : type ,
343- block : nil ,
395+ block : method_block ,
344396 location : nil
345397 )
346398 end
@@ -371,7 +423,7 @@ def self.build(leading_block, trailing_block, variables, node)
371423
372424 type_annotations = nil #: type_annotations
373425 return_annotation = nil #: Annotations::ReturnTypeAnnotation | Annotations::NodeTypeAssertion | nil
374- param_annotations = [ ] #: Array[Annotations::ParamTypeAnnotation | Annotations::SplatParamTypeAnnotation | Annotations::DoubleSplatParamTypeAnnotation]
426+ param_annotations = [ ] #: Array[Annotations::ParamTypeAnnotation | Annotations::SplatParamTypeAnnotation | Annotations::DoubleSplatParamTypeAnnotation | Annotations::BlockParamTypeAnnotation ]
375427
376428 if trailing_block
377429 case annotation = trailing_block . trailing_annotation ( variables )
@@ -405,7 +457,7 @@ def self.build(leading_block, trailing_block, variables, node)
405457 next
406458 end
407459 end
408- when Annotations ::ParamTypeAnnotation , Annotations ::SplatParamTypeAnnotation , Annotations ::DoubleSplatParamTypeAnnotation
460+ when Annotations ::ParamTypeAnnotation , Annotations ::SplatParamTypeAnnotation , Annotations ::DoubleSplatParamTypeAnnotation , Annotations :: BlockParamTypeAnnotation
409461 unless type_annotations
410462 param_annotations << paragraph
411463 next
0 commit comments