@@ -385,31 +385,63 @@ def calls_position_dependent_function?(expr)
385385 expr . any? { |part | calls_position_dependent_function? ( part ) }
386386 end
387387
# Detects simple position-based predicates that can be optimized in axis scanning,
# such as [1], [position()=1], [position() < 2], [last()], [last()-3],
# [position() >= last()-2], etc.
#
# Returns a two-element array [operator, value] in which value is zero-based:
#   [:index_eq, i]          zero-based index from the start == i
#   [:index_lt, i]          zero-based index from the start <  i
#   [:index_gt, i]          zero-based index from the start >  i
#   [:reverse_index_eq, i]  zero-based index from the end   == i (last() is 0)
#   [:reverse_index_lt, i]  zero-based index from the end   <  i
#   [:reverse_index_gt, i]  zero-based index from the end   >  i
# value may be negative (e.g. [0] yields [:index_eq, -1]); it is returned as is.
#
# Returns nil if the predicate is not a simple position-based predicate.
def position_operation(predicate_expr)
  # Bare integer predicate: [N] is shorthand for [position() = N].
  return [:index_eq, predicate_expr[1] - 1] if predicate_expr[0] == :literal && predicate_expr[1].is_a?(Integer)

  # Bare [last()] or [last()-N].
  reverse_index = last_minus_integer(predicate_expr)
  return [:reverse_index_eq, reverse_index] if reverse_index

  position_call = [:function, 'position', []]
  op, left, right = predicate_expr
  return unless op == :eq || op == :lt || op == :lteq || op == :gt || op == :gteq
  return unless [left, right].include?(position_call)

  # Normalize `operand OP position()` into `position() OP' operand` by mirroring
  # the comparison, so only the right-hand operand has to be inspected below.
  if right == position_call
    op = { eq: :eq, lt: :gt, lteq: :gteq, gt: :lt, gteq: :lteq }[op]
    right = left
  end

  # Convert the 1-based XPath position into a zero-based index.
  index = right[1] - 1 if right[0] == :literal && right[1].is_a?(Integer)
  reverse_index = last_minus_integer(right)

  if index
    case op
    when :eq
      [:index_eq, index]
    when :lt
      [:index_lt, index]
    when :lteq
      # idx <= index  <=>  idx < index + 1
      [:index_lt, index + 1]
    when :gt
      [:index_gt, index]
    when :gteq
      # idx >= index  <=>  idx > index - 1
      [:index_gt, index - 1]
    end
  elsif reverse_index
    # Comparing position() against last()-N flips direction when expressed as
    # a distance from the end: a smaller position means a larger reverse index.
    case op
    when :eq
      [:reverse_index_eq, reverse_index]
    when :lt
      [:reverse_index_gt, reverse_index]
    when :lteq
      [:reverse_index_gt, reverse_index - 1]
    when :gt
      [:reverse_index_lt, reverse_index]
    when :gteq
      [:reverse_index_lt, reverse_index + 1]
    end
  end
end
437+
# If the expression is `last()-INTEGER` or `last()` (equivalent to `last()-0`),
# returns the integer part, i.e. the zero-based distance from the last node.
# Otherwise, returns nil.
def last_minus_integer(expr)
  last_call = [:function, 'last', []]
  return 0 if expr == last_call

  operator, minuend, subtrahend = expr
  return unless operator == :minus && minuend == last_call
  # Only an integer literal on the right-hand side qualifies.
  return unless subtrahend[0] == :literal && subtrahend[1].is_a?(Integer)

  subtrahend[1]
end
415447
@@ -435,7 +467,8 @@ def following_sibling(nodeset, tester, selector)
435467
436468 def preceding_following_sibling ( nodeset , tester , selector , reverse :)
437469 nodeset = nodeset . select { |node | node . respond_to? ( :parent ) && node . parent }
438- case selector
470+ operator , value = selector
471+ case operator
439472 when :uniq
440473 nodeset . group_by ( &:parent ) . flat_map do |parent , sibling_nodes |
441474 sets = Set . new . compare_by_identity
@@ -444,15 +477,7 @@ def preceding_following_sibling(nodeset, tester, selector, reverse:)
444477 children = children . reverse if reverse
445478 children . drop_while { |child | !sets . include? ( child ) } . drop ( 1 )
446479 end . select ( &tester )
447- when :nodesets
448- nodesets = nodeset . map do |node |
449- parent = node . parent
450- index = parent . children . index ( node )
451- reverse ? parent . children [ 0 ...index ] . reverse : parent . children [ index + 1 ..-1 ]
452- end
453- non_optimized_nodesets_select ( nodesets , tester , selector )
454- else
455- operator , value = selector
480+ when :index_eq , :index_lt , :index_gt
456481 nodeset . group_by ( &:parent ) . flat_map do |parent , sibling_nodes |
457482 anchors = Set . new . compare_by_identity
458483 sibling_nodes . each { |sibling | anchors << sibling }
@@ -466,16 +491,16 @@ def preceding_following_sibling(nodeset, tester, selector, reverse:)
466491 followings . each do |node |
467492 if tester . call ( node )
468493 case operator
469- when :==
494+ when :index_eq
470495 # anchor_indexes only contain values smaller or equal to `index`,
471- # so value <= 0 case doesn't accidentally match any node.
472- matched << node if anchor_indexes . include? ( index - value + 1 )
473- when :<
496+ # so value < 0 case doesn't accidentally match any node.
497+ matched << node if anchor_indexes . include? ( index - value )
498+ when :index_lt
474499 # Position from the last anchor will be the minimum possible position for the node
475- matched << node if index - last_anchor + 1 < value
476- when :>
500+ matched << node if index - last_anchor < value
501+ when :index_gt
477502 # Position from the first anchor(==0) will be the maximum possible position for the node
478- matched << node if index + 1 > value
503+ matched << node if index > value
479504 end
480505 index += 1
481506 end
@@ -486,6 +511,13 @@ def preceding_following_sibling(nodeset, tester, selector, reverse:)
486511 end
487512 matched
488513 end
514+ else # Slow path for :nodesets, :reverse_index_eq, :reverse_index_lt, :reverse_index_gt
515+ nodesets = nodeset . map do |node |
516+ parent = node . parent
517+ index = parent . children . index ( node )
518+ reverse ? parent . children [ 0 ...index ] . reverse : parent . children [ index + 1 ..-1 ]
519+ end
520+ non_optimized_nodesets_select ( nodesets , tester , selector )
489521 end
490522 end
491523
@@ -506,7 +538,7 @@ def ancestor(nodeset, tester, selector, include_self: false)
506538 end
507539 ancestors . select ( &tester )
508540 else
509- # Slow pass
541+ # Slow path
510542 nodesets = nodeset . map do |node |
511543 ancestors = [ ]
512544 ancestors << node if include_self
@@ -537,12 +569,18 @@ def non_optimized_nodesets_select(nodesets, tester, selector)
537569 operator , value = selector
538570 nodes =
539571 case operator
540- when :==
541- nodesets . map { |nodeset | nodeset [ value - 1 ] if value >= 1 } . compact
542- when :<
543- nodesets . flat_map { |nodeset | nodeset [ 0 ...value - 1 ] if value >= 1 } . compact
544- when :>
545- nodesets . flat_map { |nodeset | value <= 0 ? nodeset : nodeset . drop ( value ) }
572+ when :index_eq
573+ nodesets . map { |nodeset | nodeset [ value ] if value >= 0 } . compact
574+ when :index_lt
575+ nodesets . flat_map { |nodeset | nodeset [ 0 ...value ] if value >= 0 } . compact
576+ when :index_gt
577+ nodesets . flat_map { |nodeset | value < 0 ? nodeset : nodeset . drop ( value + 1 ) }
578+ when :reverse_index_eq
579+ nodesets . map { |nodeset | nodeset [ -( value + 1 ) ] if value >= 0 } . compact
580+ when :reverse_index_lt
581+ nodesets . flat_map { |nodeset | nodeset [ -value ..-1 ] if value > 0 } . compact
582+ when :reverse_index_gt
583+ nodesets . flat_map { |nodeset | value < 0 ? nodeset : nodeset [ 0 ...-( value + 1 ) ] }
546584 end
547585 seen = Set . new . compare_by_identity
548586 nodes . each { |node | seen << node }
0 commit comments