Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 115 additions & 4 deletions lib/code_to_query/compiler.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
end

module CodeToQuery
# rubocop:disable Metrics/ClassLength
class Compiler
def initialize(config)
@config = config
Expand Down Expand Up @@ -182,6 +183,18 @@ def compile_with_arel(intent)
end
end

if (having_filters = intent['having']).present?
having_filters.each do |h|
agg_node = build_arel_aggregate(table, h)
next unless agg_node

key = h['param'] || "having_#{h['column']}"
bind_spec << { key: key, column: h['column'], cast: nil }
condition = build_arel_having_condition(agg_node, h['op'], key)
query = query.having(condition) if condition
end
end

if (limit = determine_appropriate_limit(intent))
query = query.take(limit)
end
Expand All @@ -197,7 +210,7 @@ def compile_with_arel(intent)
compile_with_string_building(intent)
end

# rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/BlockLength
# rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/BlockLength, Metrics/CyclomaticComplexity
# NOTE: This method is intentionally monolithic for clarity and to avoid regressions in SQL assembly.
# TODO: Extract EXISTS/NOT EXISTS handling and simple predicate building into small helpers.
def compile_with_string_building(intent)
Expand Down Expand Up @@ -300,6 +313,8 @@ def compile_with_string_building(intent)
sub_where << "#{rcol} BETWEEN #{placeholder1} AND #{placeholder2}"
when 'in'
key = rf['param'] || rf['column']
values = params_hash[key] || params_hash[key.to_s] || params_hash[key.to_sym]
validate_in_clause_values!(values, rf['column'])
placeholder = placeholder_for_adapter(placeholder_index)
bind_spec << ({ key: key, column: rf['column'], cast: :array })
placeholder_index += 1
Expand Down Expand Up @@ -363,6 +378,8 @@ def compile_with_string_building(intent)
sub_where << "#{rcol} BETWEEN #{placeholder1} AND #{placeholder2}"
when 'in'
key = rf['param'] || rf['column']
values = params_hash[key] || params_hash[key.to_s] || params_hash[key.to_sym]
validate_in_clause_values!(values, rf['column'])
placeholder = placeholder_for_adapter(placeholder_index)
bind_spec << { key: key, column: rf['column'], cast: :array }
placeholder_index += 1
Expand Down Expand Up @@ -394,7 +411,11 @@ def compile_with_string_building(intent)
"#{col} BETWEEN #{placeholder1} AND #{placeholder2}"
when 'in'
key = f['param'] || f['column']
# For IN clauses, we'll need to handle arrays specially
values = params_hash[key] || params_hash[key.to_s] || params_hash[key.to_sym]
if values.is_a?(Array) && values.empty?
raise ArgumentError, "IN clause requires non-empty array for column '#{f['column']}'"
end

placeholder = placeholder_for_adapter(placeholder_index)
bind_spec << { key: key, column: f['column'], cast: :array }
placeholder_index += 1
Expand All @@ -412,6 +433,22 @@ def compile_with_string_building(intent)
sql_parts << "WHERE #{where_fragments.join(' AND ')}" if where_fragments.any?
end

if (group_columns = intent['group_by']).present?
group_fragments = group_columns.map { |col| quote_ident(col) }
sql_parts << "GROUP BY #{group_fragments.join(', ')}"
end

if (having_filters = intent['having']).present?
having_fragments = having_filters.map do |h|
agg_expr = build_aggregate_expression(h)
placeholder = placeholder_for_adapter(placeholder_index)
bind_spec << { key: h['param'] || "having_#{h['column']}", column: h['column'], cast: nil }
placeholder_index += 1
"#{agg_expr} #{h['op']} #{placeholder}"
end
sql_parts << "HAVING #{having_fragments.join(' AND ')}"
end

if (orders = intent['order']).present?
order_fragments = orders.map do |o|
dir = o['dir'].to_s.downcase == 'desc' ? 'DESC' : 'ASC'
Expand All @@ -426,7 +463,7 @@ def compile_with_string_building(intent)

{ sql: sql_parts.join(' '), params: params_hash, bind_spec: bind_spec }
end
# rubocop:enable Metrics/AbcSize, Metrics/MethodLength, Metrics/BlockLength
# rubocop:enable Metrics/AbcSize, Metrics/MethodLength, Metrics/BlockLength, Metrics/CyclomaticComplexity

def apply_policy_in_subquery(sub_where, bind_spec, related_table, placeholder_index)
return [sub_where, placeholder_index] unless @config.policy_adapter.respond_to?(:call)
Expand Down Expand Up @@ -614,13 +651,86 @@ def parse_function_column(expr)
s = expr.to_s.strip
return nil unless s.include?('(') && s.end_with?(')')

if (m = s.match(/\A\s*(count|sum|avg|max|min)\s*\(\s*(\*|[a-zA-Z0-9_\.]+)\s*\)\s*\z/i))
if (m = s.match(/\A\s*(count|sum|avg|max|min)\s*\(\s*(\*|[a-zA-Z0-9_.]+)\s*\)\s*\z/i))
func = m[1].downcase
col = m[2] == '*' ? nil : m[2]
{ func: func, column: col }
end
end

def build_aggregate_expression(having_spec)
func = having_spec['function'].to_s.upcase
col = having_spec['column']

case func
when 'COUNT'
col ? "COUNT(#{quote_ident(col)})" : 'COUNT(*)'
when 'SUM'
"SUM(#{quote_ident(col)})"
when 'AVG'
"AVG(#{quote_ident(col)})"
when 'MAX'
"MAX(#{quote_ident(col)})"
when 'MIN'
"MIN(#{quote_ident(col)})"
else
'COUNT(*)'
end
end

def validate_in_clause_values!(values, column)
return unless values.is_a?(Array) && values.empty?

raise ArgumentError, "IN clause requires non-empty array for column '#{column}'"
end

def build_arel_aggregate(table, having_spec)
func = having_spec['function'].to_s.downcase
col = having_spec['column']

case func
when 'count'
col ? table[col].count : Arel.star.count
when 'sum'
return nil unless col

table[col].sum
when 'avg'
return nil unless col

table[col].average
when 'max'
return nil unless col

table[col].maximum
when 'min'
return nil unless col

table[col].minimum
else
Arel.star.count
end
end

def build_arel_having_condition(agg_node, operator, key)
bind_param = Arel::Nodes::BindParam.new(key)

case operator
when '='
agg_node.eq(bind_param)
when '!='
agg_node.not_eq(bind_param)
when '>'
agg_node.gt(bind_param)
when '>='
agg_node.gteq(bind_param)
when '<'
agg_node.lt(bind_param)
when '<='
agg_node.lteq(bind_param)
end
end

def normalize_params_with_model(intent)
params = (intent['params'] || {}).dup
return params unless defined?(ActiveRecord::Base)
Expand Down Expand Up @@ -671,4 +781,5 @@ def infer_model_for_table(table_name)
nil
end
end
# rubocop:enable Metrics/ClassLength
end
Loading