Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 79 additions & 41 deletions lib/rexml/xpath_parser.rb
Original file line number Diff line number Diff line change
Expand Up @@ -385,31 +385,63 @@ def calls_position_dependent_function?(expr)
expr.any? {|part| calls_position_dependent_function?(part) }
end

# Detects simple position-based predicates that can be optimized in axis scanning, such as [1], [position()=1], [position() < 2], [position() > 3]
# Returns operators and values such as [:==, 1], [:<, 2], [:>, 3]
# Detects simple position-based predicates that can be optimized in axis scanning, such as [1], [position()=1], [position() < 2], [last()-3], etc.
# Returns operators and values such as [:index_eq, 0], [:index_lt, 1], [:index_gt, 2], [:reverse_index_eq, 0], [:reverse_index_lt, 1], [:reverse_index_gt, 2]
# Returns nil if the predicate is not a simple position-based predicate
def position_operation(predicate_expr)
return [:==, predicate_expr[1]] if predicate_expr[0] == :literal && predicate_expr[1].is_a?(Integer)
return [:index_eq, predicate_expr[1] - 1] if predicate_expr[0] == :literal && predicate_expr[1].is_a?(Integer)

reverse_index = last_minus_integer(predicate_expr)
return [:reverse_index_eq, reverse_index] if reverse_index

op, left, right = predicate_expr
return unless op == :eq || op == :lt || op == :lteq || op == :gt || op == :gteq
return unless [left, right].include?([:function, 'position', []])

literal = [left, right].find {|part| part[0] == :literal && part[1].is_a?(Integer) }
return unless literal
if right == [:function, 'position', []]
op = { eq: :eq, lt: :gt, lteq: :gteq, gt: :lt, gteq: :lteq }[op]
left, right = right, left
end

value = literal[1]
case op
when :eq
[:==, value]
when :lt
literal == right ? [:<, value] : [:>, value]
when :lteq
literal == right ? [:<, value + 1] : [:>, value - 1]
when :gt
literal == right ? [:>, value]: [:<, value]
when :gteq
literal == right ? [:>, value - 1] : [:<, value + 1]
index = right[1] - 1 if right[0] == :literal && right[1].is_a?(Integer)
reverse_index = last_minus_integer(right)

if index
case op
when :eq
[:index_eq, index]
when :lt
[:index_lt, index]
when :lteq
[:index_lt, index + 1]
when :gt
[:index_gt, index]
when :gteq
[:index_gt, index - 1]
end
elsif reverse_index
case op
when :eq
[:reverse_index_eq, reverse_index]
when :lt
[:reverse_index_gt, reverse_index]
when :lteq
[:reverse_index_gt, reverse_index - 1]
when :gt
[:reverse_index_lt, reverse_index]
when :gteq
[:reverse_index_lt, reverse_index + 1]
end
end
end

# Recognizes the parsed forms of `last()` and `last()-INTEGER`.
# Returns the INTEGER being subtracted, treating a bare `last()` as `last()-0`.
# Returns nil for any other expression shape.
def last_minus_integer(expr)
  last_call = [:function, 'last', []]
  return 0 if expr == last_call
  # Only a subtraction whose left operand is exactly `last()` qualifies.
  return unless expr[0] == :minus && expr[1] == last_call

  operand = expr[2]
  return unless operand[0] == :literal

  # Non-integer literals (e.g. floats) are not accepted here.
  amount = operand[1]
  amount if amount.is_a?(Integer)
end

Expand All @@ -435,7 +467,8 @@ def following_sibling(nodeset, tester, selector)

def preceding_following_sibling(nodeset, tester, selector, reverse:)
nodeset = nodeset.select {|node| node.respond_to?(:parent) && node.parent }
case selector
operator, value = selector
case operator
when :uniq
nodeset.group_by(&:parent).flat_map do |parent, sibling_nodes|
sets = Set.new.compare_by_identity
Expand All @@ -444,15 +477,7 @@ def preceding_following_sibling(nodeset, tester, selector, reverse:)
children = children.reverse if reverse
children.drop_while {|child| !sets.include?(child) }.drop(1)
end.select(&tester)
when :nodesets
nodesets = nodeset.map do |node|
parent = node.parent
index = parent.children.index(node)
reverse ? parent.children[0...index].reverse : parent.children[index + 1..-1]
end
non_optimized_nodesets_select(nodesets, tester, selector)
else
operator, value = selector
when :index_eq, :index_lt, :index_gt
nodeset.group_by(&:parent).flat_map do |parent, sibling_nodes|
anchors = Set.new.compare_by_identity
sibling_nodes.each {|sibling| anchors << sibling }
Expand All @@ -466,16 +491,16 @@ def preceding_following_sibling(nodeset, tester, selector, reverse:)
followings.each do |node|
if tester.call(node)
case operator
when :==
when :index_eq
# anchor_indexes only contain values smaller or equal to `index`,
# so value <= 0 case doesn't accidentally match any node.
matched << node if anchor_indexes.include?(index - value + 1)
when :<
# so value < 0 case doesn't accidentally match any node.
matched << node if anchor_indexes.include?(index - value)
when :index_lt
# Position from the last anchor will be the minimum possible position for the node
matched << node if index - last_anchor + 1 < value
when :>
matched << node if index - last_anchor < value
when :index_gt
# Position from the first anchor(==0) will be the maximum possible position for the node
matched << node if index + 1 > value
matched << node if index > value
end
index += 1
end
Expand All @@ -486,6 +511,13 @@ def preceding_following_sibling(nodeset, tester, selector, reverse:)
end
matched
end
else # Slow path for :nodesets, :reverse_index_eq, :reverse_index_lt, :reverse_index_gt
nodesets = nodeset.map do |node|
parent = node.parent
index = parent.children.index(node)
reverse ? parent.children[0...index].reverse : parent.children[index + 1..-1]
end
non_optimized_nodesets_select(nodesets, tester, selector)
end
end

Expand All @@ -506,7 +538,7 @@ def ancestor(nodeset, tester, selector, include_self: false)
end
ancestors.select(&tester)
else
# Slow pass
# Slow path
nodesets = nodeset.map do |node|
ancestors = []
ancestors << node if include_self
Expand Down Expand Up @@ -537,12 +569,18 @@ def non_optimized_nodesets_select(nodesets, tester, selector)
operator, value = selector
nodes =
case operator
when :==
nodesets.map {|nodeset| nodeset[value - 1] if value >= 1 }.compact
when :<
nodesets.flat_map {|nodeset| nodeset[0...value - 1] if value >= 1 }.compact
when :>
nodesets.flat_map {|nodeset| value <= 0 ? nodeset : nodeset.drop(value) }
when :index_eq
nodesets.map {|nodeset| nodeset[value] if value >= 0 }.compact
when :index_lt
nodesets.flat_map {|nodeset| nodeset[0...value] if value >= 0 }.compact
when :index_gt
nodesets.flat_map {|nodeset| value < 0 ? nodeset : nodeset.drop(value + 1) }
when :reverse_index_eq
nodesets.map {|nodeset| nodeset[-(value + 1)] if value >= 0 }.compact
when :reverse_index_lt
nodesets.flat_map {|nodeset| nodeset.last(value) if value > 0 }.compact
when :reverse_index_gt
nodesets.flat_map {|nodeset| value < 0 ? nodeset : nodeset[0...-(value + 1)] }
end
seen = Set.new.compare_by_identity
nodes.each {|node| seen << node }
Expand Down
14 changes: 14 additions & 0 deletions test/xpath/test_axis_preceding_sibling.rb
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_preceding_following_sibling_multiple_anchors

assert_equal(%w[2 7 9], XPath.match(doc, "/a/anchor/preceding-sibling::b[position() = 3]").map {|n| n.attributes["id"] })
assert_equal(%w[2 3 4 7 8 9 10 11], XPath.match(doc, "/a/anchor/preceding-sibling::b[position() <= 3]").map {|n| n.attributes["id"] })
assert_equal(%w[2 3 4 7 8 9 10 11], XPath.match(doc, "/a/anchor/preceding-sibling::b[4 > position()]").map {|n| n.attributes["id"] })
assert_equal(%w[1 2 3 4 5 6 7 8], XPath.match(doc, "/a/anchor/preceding-sibling::b[position() >= 4]").map {|n| n.attributes["id"] })
assert_equal(%w[2 7 a2], XPath.match(doc, "/a/anchor/preceding-sibling::*[@id][position() = 3]").map {|n| n.attributes["id"] })
assert_equal(%w[2 3 4 7 8 9 a2 10 11], XPath.match(doc, "/a/anchor/preceding-sibling::*[@id][position() <= 3]").map {|n| n.attributes["id"] })
Expand All @@ -108,9 +109,22 @@ def test_preceding_following_sibling_multiple_anchors
assert_equal(%w[7 12], XPath.match(doc, "/a/anchor/following-sibling::b[position() = 3]").map {|n| n.attributes["id"] })
assert_equal(%w[5 6 7 10 11 12], XPath.match(doc, "/a/anchor/following-sibling::b[position() <= 3]").map {|n| n.attributes["id"] })
assert_equal(%w[8 9 10 11 12], XPath.match(doc, "/a/anchor/following-sibling::b[position() >= 4]").map {|n| n.attributes["id"] })
assert_equal(%w[8 9 10 11 12], XPath.match(doc, "/a/anchor/following-sibling::b[3 < position()]").map {|n| n.attributes["id"] })
assert_equal(%w[7 a3], XPath.match(doc, "/a/anchor/following-sibling::*[@id][position() = 3]").map {|n| n.attributes["id"] })
assert_equal(%w[5 6 7 10 11 a3 12], XPath.match(doc, "/a/anchor/following-sibling::*[@id][position() <= 3]").map {|n| n.attributes["id"] })
assert_equal(%w[8 9 a2 10 11 a3 12], XPath.match(doc, "/a/anchor/following-sibling::*[@id][position() >= 4]").map {|n| n.attributes["id"] })

assert_equal(%w[1], XPath.match(doc, "/a/anchor/preceding-sibling::b[last()]").map {|n| n.attributes["id"] })
assert_equal(%w[4], XPath.match(doc, "/a/anchor/preceding-sibling::b[last() - 3]").map {|n| n.attributes["id"] })
assert_equal(%w[4 5 6 7 8 9 10 11], XPath.match(doc, "/a/anchor/preceding-sibling::b[position() <= last() - 3]").map {|n| n.attributes["id"] })
assert_equal(%w[4 5 6 7 8 9 10 11], XPath.match(doc, "/a/anchor/preceding-sibling::b[last() - 2 > position()]").map {|n| n.attributes["id"] })
assert_equal(%w[1 2 3 4 5], XPath.match(doc, "/a/anchor/preceding-sibling::b[position() >= last() - 4]").map {|n| n.attributes["id"] })

assert_equal(%w[12], XPath.match(doc, "/a/anchor/following-sibling::b[last()]").map {|n| n.attributes["id"] })
assert_equal(%w[9], XPath.match(doc, "/a/anchor/following-sibling::b[last() - 3]").map {|n| n.attributes["id"] })
assert_equal(%w[5 6 7 8 9], XPath.match(doc, "/a/anchor/following-sibling::b[position() <= last() - 3]").map {|n| n.attributes["id"] })
assert_equal(%w[8 9 10 11 12], XPath.match(doc, "/a/anchor/following-sibling::b[position() >= last() - 4]").map {|n| n.attributes["id"] })
assert_equal(%w[8 9 10 11 12], XPath.match(doc, "/a/anchor/following-sibling::b[last() - 5 < position()]").map {|n| n.attributes["id"] })
end
end
end
6 changes: 6 additions & 0 deletions test/xpath/test_predicate.rb
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def test_predicate_out_of_range_position
assert_equal(%w[a b c d], parser.parse("#{base}[position()>0]", doc).map(&:name))
assert_equal(%w[a b c d], parser.parse("#{base}[position()>-1]", doc).map(&:name))
assert_equal(%w[a b c d], parser.parse("#{base}[position()<10]", doc).map(&:name))
assert_equal(%w[], parser.parse("#{base}[position()<last()-10]", doc).map(&:name))
assert_equal(%w[a b c d], parser.parse("#{base}[position()>last()-10]", doc).map(&:name))
assert_equal(%w[], parser.parse("#{base}[last()-10]", doc).map(&:name))

# non-optimizable case
base_no_opt = '/r/*[position()!=name()]'
Expand All @@ -128,6 +131,9 @@ def test_predicate_out_of_range_position
assert_equal(%w[a b c d], parser.parse("#{base_no_opt}[position()>0]", doc).map(&:name))
assert_equal(%w[a b c d], parser.parse("#{base_no_opt}[position()>-1]", doc).map(&:name))
assert_equal(%w[a b c d], parser.parse("#{base_no_opt}[position()<10]", doc).map(&:name))
assert_equal(%w[], parser.parse("#{base_no_opt}[position()<last()-10]", doc).map(&:name))
assert_equal(%w[a b c d], parser.parse("#{base_no_opt}[position()>last()-10]", doc).map(&:name))
assert_equal(%w[], parser.parse("#{base_no_opt}[last()-10]", doc).map(&:name))
end

def test_predicate_parenthesized_position
Expand Down
Loading