From 15921d7075716d12f16fb91eef5615491692b491 Mon Sep 17 00:00:00 2001 From: Alex Kholodniak Date: Tue, 6 Jan 2026 04:47:33 +0200 Subject: [PATCH] fix SQL generation issues and add comprehensive tests - Add GROUP BY support in string builder compilation path - Implement HAVING clause support for both Arel and string builder - Add validation to prevent empty array IN clauses (syntax error) - Add comprehensive tests for compiler, validator, runner, explain_gate - Fix query_spec tests for safe? and binds methods --- lib/code_to_query/compiler.rb | 119 ++++- spec/code_to_query/compiler_spec.rb | 252 +++++++++ .../guardrails/explain_gate_spec.rb | 494 ++++++++++++++++++ spec/code_to_query/query_spec.rb | 96 +++- spec/code_to_query/runner_spec.rb | 265 ++++++++++ spec/code_to_query/validator_spec.rb | 374 +++++++++++++ 6 files changed, 1584 insertions(+), 16 deletions(-) create mode 100644 spec/code_to_query/guardrails/explain_gate_spec.rb create mode 100644 spec/code_to_query/runner_spec.rb create mode 100644 spec/code_to_query/validator_spec.rb diff --git a/lib/code_to_query/compiler.rb b/lib/code_to_query/compiler.rb index 0a5f54a..f514532 100644 --- a/lib/code_to_query/compiler.rb +++ b/lib/code_to_query/compiler.rb @@ -9,6 +9,7 @@ end module CodeToQuery + # rubocop:disable Metrics/ClassLength class Compiler def initialize(config) @config = config @@ -182,6 +183,18 @@ def compile_with_arel(intent) end end + if (having_filters = intent['having']).present? + having_filters.each do |h| + agg_node = build_arel_aggregate(table, h) + next unless agg_node + + key = h['param'] || "having_#{h['column']}" + bind_spec << { key: key, column: h['column'], cast: nil } + condition = build_arel_having_condition(agg_node, h['op'], key) + query = query.having(condition) if condition + end + end + if (limit = determine_appropriate_limit(intent)) query = query.take(limit) end @@ -197,7 +210,7 @@ def compile_with_arel(intent) compile_with_string_building(intent) end - # rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/BlockLength + # rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/BlockLength, Metrics/CyclomaticComplexity # NOTE: This method is intentionally monolithic for clarity and to avoid regressions in SQL assembly. # TODO: Extract EXISTS/NOT EXISTS handling and simple predicate building into small helpers. def compile_with_string_building(intent) @@ -300,6 +313,8 @@ def compile_with_string_building(intent) sub_where << "#{rcol} BETWEEN #{placeholder1} AND #{placeholder2}" when 'in' key = rf['param'] || rf['column'] + values = params_hash[key] || params_hash[key.to_s] || params_hash[key.to_sym] + validate_in_clause_values!(values, rf['column']) placeholder = placeholder_for_adapter(placeholder_index) bind_spec << ({ key: key, column: rf['column'], cast: :array }) placeholder_index += 1 @@ -363,6 +378,8 @@ def compile_with_string_building(intent) sub_where << "#{rcol} BETWEEN #{placeholder1} AND #{placeholder2}" when 'in' key = rf['param'] || rf['column'] + values = params_hash[key] || params_hash[key.to_s] || params_hash[key.to_sym] + validate_in_clause_values!(values, rf['column']) placeholder = placeholder_for_adapter(placeholder_index) bind_spec << { key: key, column: rf['column'], cast: :array } placeholder_index += 1 @@ -394,7 +411,11 @@ def compile_with_string_building(intent) "#{col} BETWEEN #{placeholder1} AND #{placeholder2}" when 'in' key = f['param'] || f['column'] - # For IN clauses, we'll need to handle arrays specially + values = params_hash[key] || params_hash[key.to_s] || params_hash[key.to_sym] + if values.is_a?(Array) && values.empty? + raise ArgumentError, "IN clause requires non-empty array for column '#{f['column']}'" + end + placeholder = placeholder_for_adapter(placeholder_index) bind_spec << { key: key, column: f['column'], cast: :array } placeholder_index += 1 @@ -412,6 +433,22 @@ def compile_with_string_building(intent) sql_parts << "WHERE #{where_fragments.join(' AND ')}" if where_fragments.any? end + if (group_columns = intent['group_by']).present? + group_fragments = group_columns.map { |col| quote_ident(col) } + sql_parts << "GROUP BY #{group_fragments.join(', ')}" + end + + if (having_filters = intent['having']).present? + having_fragments = having_filters.map do |h| + agg_expr = build_aggregate_expression(h) + placeholder = placeholder_for_adapter(placeholder_index) + bind_spec << { key: h['param'] || "having_#{h['column']}", column: h['column'], cast: nil } + placeholder_index += 1 + "#{agg_expr} #{h['op']} #{placeholder}" + end + sql_parts << "HAVING #{having_fragments.join(' AND ')}" + end + if (orders = intent['order']).present? order_fragments = orders.map do |o| dir = o['dir'].to_s.downcase == 'desc' ? 'DESC' : 'ASC' @@ -426,7 +463,7 @@ def compile_with_string_building(intent) { sql: sql_parts.join(' '), params: params_hash, bind_spec: bind_spec } end - # rubocop:enable Metrics/AbcSize, Metrics/MethodLength, Metrics/BlockLength + # rubocop:enable Metrics/AbcSize, Metrics/MethodLength, Metrics/BlockLength, Metrics/CyclomaticComplexity def apply_policy_in_subquery(sub_where, bind_spec, related_table, placeholder_index) return [sub_where, placeholder_index] unless @config.policy_adapter.respond_to?(:call) @@ -614,13 +651,86 @@ def parse_function_column(expr) s = expr.to_s.strip return nil unless s.include?('(') && s.end_with?(')') - if (m = s.match(/\A\s*(count|sum|avg|max|min)\s*\(\s*(\*|[a-zA-Z0-9_\.]+)\s*\)\s*\z/i)) + if (m = s.match(/\A\s*(count|sum|avg|max|min)\s*\(\s*(\*|[a-zA-Z0-9_.]+)\s*\)\s*\z/i)) func = m[1].downcase col = m[2] == '*' ? nil : m[2] { func: func, column: col } end end + def build_aggregate_expression(having_spec) + func = having_spec['function'].to_s.upcase + col = having_spec['column'] + + case func + when 'COUNT' + col ? "COUNT(#{quote_ident(col)})" : 'COUNT(*)' + when 'SUM' + "SUM(#{quote_ident(col)})" + when 'AVG' + "AVG(#{quote_ident(col)})" + when 'MAX' + "MAX(#{quote_ident(col)})" + when 'MIN' + "MIN(#{quote_ident(col)})" + else + 'COUNT(*)' + end + end + + def validate_in_clause_values!(values, column) + return unless values.is_a?(Array) && values.empty? + + raise ArgumentError, "IN clause requires non-empty array for column '#{column}'" + end + + def build_arel_aggregate(table, having_spec) + func = having_spec['function'].to_s.downcase + col = having_spec['column'] + + case func + when 'count' + col ? table[col].count : Arel.star.count + when 'sum' + return nil unless col + + table[col].sum + when 'avg' + return nil unless col + + table[col].average + when 'max' + return nil unless col + + table[col].maximum + when 'min' + return nil unless col + + table[col].minimum + else + Arel.star.count + end + end + + def build_arel_having_condition(agg_node, operator, key) + bind_param = Arel::Nodes::BindParam.new(key) + + case operator + when '=' + agg_node.eq(bind_param) + when '!=' + agg_node.not_eq(bind_param) + when '>' + agg_node.gt(bind_param) + when '>=' + agg_node.gteq(bind_param) + when '<' + agg_node.lt(bind_param) + when '<=' + agg_node.lteq(bind_param) + end + end + def normalize_params_with_model(intent) params = (intent['params'] || {}).dup return params unless defined?(ActiveRecord::Base) @@ -671,4 +781,5 @@ def infer_model_for_table(table_name) nil end end + # rubocop:enable Metrics/ClassLength end diff --git a/spec/code_to_query/compiler_spec.rb b/spec/code_to_query/compiler_spec.rb index 423fbf8..afa0e9f 100644 --- a/spec/code_to_query/compiler_spec.rb +++ b/spec/code_to_query/compiler_spec.rb @@ -209,5 +209,257 @@ expect(result[:sql]).to include('SELECT "id", "email", "created_at"') end end + + context 'with GROUP BY clause' do + let(:intent) do + { + 'table' => 'orders', + 'columns' => ['user_id', 'COUNT(*)'], + 'group_by' => ['user_id'], + 'limit' => 100 + } + end + + it 'generates GROUP BY clause' do + result = compiler.compile(intent) + + expect(result[:sql]).to include('GROUP BY "user_id"') + expect(result[:sql]).to include('COUNT(*)') + end + end + + context 'with GROUP BY on multiple columns' do + let(:intent) do + { + 'table' => 'orders', + 'columns' => ['user_id', 'status', 'COUNT(*)'], + 'group_by' => %w[user_id status], + 'limit' => 100 + } + end + + it 'generates GROUP BY with multiple columns' do + result = compiler.compile(intent) + + expect(result[:sql]).to include('GROUP BY "user_id", "status"') + end + end + + context 'with HAVING clause' do + let(:intent) do + { + 'table' => 'orders', + 'columns' => ['user_id', 'COUNT(*)'], + 'group_by' => ['user_id'], + 'having' => [ + { 'function' => 'count', 'column' => nil, 'op' => '>', 'param' => 'min_orders' } + ], + 'limit' => 100, + 'params' => { 'min_orders' => 5 } + } + end + + it 'generates HAVING clause with COUNT' do + result = compiler.compile(intent) + + expect(result[:sql]).to include('GROUP BY "user_id"') + expect(result[:sql]).to include('HAVING COUNT(*) > $1') + expect(result[:bind_spec]).to include(hash_including(key: 'min_orders')) + end + end + + context 'with HAVING clause using SUM' do + let(:intent) do + { + 'table' => 'orders', + 'columns' => ['user_id', 'SUM(amount)'], + 'group_by' => ['user_id'], + 'having' => [ + { 'function' => 'sum', 'column' => 'amount', 'op' => '>=', 'param' => 'min_amount' } + ], + 'limit' => 100, + 'params' => { 'min_amount' => 1000 } + } + end + + it 'generates HAVING clause with SUM' do + result = compiler.compile(intent) + + expect(result[:sql]).to include('HAVING SUM("amount") >= $1') + expect(result[:bind_spec]).to include(hash_including(key: 'min_amount')) + end + end + + context 'with IN clause' do + let(:intent) do + { + 'table' => 'users', + 'columns' => ['*'], + 'filters' => [ + { 'column' => 'status', 'op' => 'in', 'param' => 'statuses' } + ], + 'limit' => 100, + 'params' => { 'statuses' => %w[active pending] } + } + end + + it 'generates IN clause' do + result = compiler.compile(intent) + + expect(result[:sql]).to include('"status" IN ($1)') + expect(result[:bind_spec]).to include(hash_including(key: 'statuses', cast: :array)) + end + end + + context 'with empty array in IN clause' do + let(:intent) do + { + 'table' => 'users', + 'columns' => ['*'], + 'filters' => [ + { 'column' => 'status', 'op' => 'in', 'param' => 'statuses' } + ], + 'limit' => 100, + 'params' => { 'statuses' => [] } + } + end + + it 'raises an error for empty array' do + expect { compiler.compile(intent) }.to raise_error( + ArgumentError, /IN clause requires non-empty array/ + ) + end + end + + context 'with LIKE filter' do + let(:intent) do + { + 'table' => 'users', + 'columns' => ['*'], + 'filters' => [ + { 'column' => 'email', 'op' => 'like', 'param' => 'email_pattern' } + ], + 'limit' => 100, + 'params' => { 'email_pattern' => '%@example.com' } + } + end + + it 'generates LIKE clause' do + result = compiler.compile(intent) + + expect(result[:sql]).to include('"email" LIKE $1') + end + end + + context 'with ILIKE filter (PostgreSQL)' do + let(:intent) do + { + 'table' => 'users', + 'columns' => ['*'], + 'filters' => [ + { 'column' => 'name', 'op' => 'ilike', 'param' => 'name_pattern' } + ], + 'limit' => 100, + 'params' => { 'name_pattern' => '%john%' } + } + end + + it 'generates ILIKE clause' do + result = compiler.compile(intent) + + expect(result[:sql]).to include('"name" ILIKE $1') + end + end + + context 'with aggregate functions' do + let(:intent) do + { + 'table' => 'orders', + 'columns' => ['SUM(amount)'], + 'limit' => 100 + } + end + + it 'generates SUM aggregate' do + result = compiler.compile(intent) + + expect(result[:sql]).to include('SUM("amount") as sum') + end + end + + context 'with AVG aggregate' do + let(:intent) do + { + 'table' => 'orders', + 'columns' => ['AVG(amount)'], + 'limit' => 100 + } + end + + it 'generates AVG aggregate' do + result = compiler.compile(intent) + + expect(result[:sql]).to include('AVG("amount") as avg') + end + end + + context 'with multiple filters combined' do + let(:intent) do + { + 'table' => 'users', + 'columns' => ['*'], + 'filters' => [ + { 'column' => 'active', 'op' => '=', 'param' => 'is_active' }, + { 'column' => 'age', 'op' => '>=', 'param' => 'min_age' }, + { 'column' => 'role', 'op' => 'in', 'param' => 'roles' } + ], + 'limit' => 100, + 'params' => { + 'is_active' => true, + 'min_age' => 18, + 'roles' => %w[admin moderator] + } + } + end + + it 'generates WHERE with AND logic' do + result = compiler.compile(intent) + + expect(result[:sql]).to include('"active" = $1') + expect(result[:sql]).to include('"age" >= $2') + expect(result[:sql]).to include('"role" IN ($3)') + expect(result[:sql]).to include(' AND ') + end + end + + context 'with GROUP BY, HAVING, and ORDER BY combined' do + let(:intent) do + { + 'table' => 'orders', + 'columns' => ['user_id', 'COUNT(*)'], + 'group_by' => ['user_id'], + 'having' => [ + { 'function' => 'count', 'column' => nil, 'op' => '>', 'param' => 'min_orders' } + ], + 'order' => [{ 'column' => 'user_id', 'dir' => 'asc' }], + 'limit' => 100, + 'params' => { 'min_orders' => 10 } + } + end + + it 'generates correct clause order' do + result = compiler.compile(intent) + sql = result[:sql] + + group_pos = sql.index('GROUP BY') + having_pos = sql.index('HAVING') + order_pos = sql.index('ORDER BY') + limit_pos = sql.index('LIMIT') + + expect(group_pos).to be < having_pos + expect(having_pos).to be < order_pos + expect(order_pos).to be < limit_pos + end + end end end diff --git a/spec/code_to_query/guardrails/explain_gate_spec.rb b/spec/code_to_query/guardrails/explain_gate_spec.rb new file mode 100644 index 0000000..ff61347 --- /dev/null +++ b/spec/code_to_query/guardrails/explain_gate_spec.rb @@ -0,0 +1,494 @@ +# frozen_string_literal: true + +require 'spec_helper' + +RSpec.describe CodeToQuery::Guardrails::ExplainGate do + let(:config) do + stub_config( + adapter: :postgres, + max_query_cost: 10_000, + max_query_rows: 100_000, + allow_seq_scans: false, + explain_fail_open: true + ) + end + let(:gate) { described_class.new(config) } + + describe '#allowed?' do + context 'when ActiveRecord is not available' do + before do + hide_const('ActiveRecord::Base') if defined?(ActiveRecord::Base) + end + + it 'returns true (allows query)' do + expect(gate.allowed?('SELECT * FROM users')).to be true + end + end + + context 'when explain plan is empty' do + before do + ar_base = Class.new do + def self.connected? + true + end + end + stub_const('ActiveRecord::Base', ar_base) + allow(gate).to receive(:get_explain_plan).and_return([]) + end + + it 'returns true' do + expect(gate.allowed?('SELECT * FROM users')).to be true + end + end + + context 'when explain plan is nil' do + before do + ar_base = Class.new do + def self.connected? + true + end + end + stub_const('ActiveRecord::Base', ar_base) + allow(gate).to receive(:get_explain_plan).and_return(nil) + end + + it 'returns true' do + expect(gate.allowed?('SELECT * FROM users')).to be true + end + end + + context 'when an error occurs and explain_fail_open is true' do + before do + ar_base = Class.new do + def self.connected? + true + end + end + stub_const('ActiveRecord::Base', ar_base) + allow(gate).to receive(:get_explain_plan).and_raise(StandardError, 'connection error') + end + + it 'returns true (fail-open)' do + expect(gate.allowed?('SELECT * FROM users')).to be true + end + end + + context 'when an error occurs and explain_fail_open is false' do + before do + config.explain_fail_open = false + ar_base = Class.new do + def self.connected? + true + end + end + stub_const('ActiveRecord::Base', ar_base) + allow(gate).to receive(:get_explain_plan).and_raise(StandardError, 'connection error') + end + + it 'returns false (fail-closed)' do + expect(gate.allowed?('SELECT * FROM users')).to be false + end + end + end + + describe '#build_explain_query' do + context 'with PostgreSQL adapter' do + before { config.adapter = :postgres } + + it 'builds EXPLAIN with JSON format' do + sql = gate.send(:build_explain_query, 'SELECT * FROM users') + expect(sql).to include('EXPLAIN (ANALYZE false, BUFFERS false, VERBOSE false, FORMAT JSON)') + expect(sql).to include('SELECT * FROM users') + end + end + + context 'with MySQL adapter' do + before { config.adapter = :mysql } + + it 'builds EXPLAIN with JSON format' do + sql = gate.send(:build_explain_query, 'SELECT * FROM users') + expect(sql).to eq('EXPLAIN FORMAT=JSON SELECT * FROM users') + end + end + + context 'with SQLite adapter' do + before { config.adapter = :sqlite } + + it 'builds EXPLAIN QUERY PLAN' do + sql = gate.send(:build_explain_query, 'SELECT * FROM users') + expect(sql).to eq('EXPLAIN QUERY PLAN SELECT * FROM users') + end + end + + context 'with unknown adapter' do + before { config.adapter = :unknown } + + it 'builds simple EXPLAIN' do + sql = gate.send(:build_explain_query, 'SELECT * FROM users') + expect(sql).to eq('EXPLAIN SELECT * FROM users') + end + end + end + + describe '#normalize_explain_result' do + it 'handles array of hashes' do + result = [{ 'Plan' => 'something' }] + normalized = gate.send(:normalize_explain_result, result) + expect(normalized).to eq([{ 'Plan' => 'something' }]) + end + + it 'handles array of arrays' do + result = [['plan text']] + normalized = gate.send(:normalize_explain_result, result) + expect(normalized).to eq(['plan text']) + end + + it 'handles non-array result' do + result = 'plan text' + normalized = gate.send(:normalize_explain_result, result) + expect(normalized).to eq(['plan text']) + end + end + + describe '#check_node_safety' do + context 'when cost exceeds limit' do + let(:node) do + { + 'Node Type' => 'Seq Scan', + 'Total Cost' => 15_000, + 'Plan Rows' => 100 + } + end + + it 'returns false' do + expect(gate.send(:check_node_safety, node)).to be false + end + end + + context 'when rows exceed limit' do + let(:node) do + { + 'Node Type' => 'Seq Scan', + 'Total Cost' => 100, + 'Plan Rows' => 200_000 + } + end + + it 'returns false' do + expect(gate.send(:check_node_safety, node)).to be false + end + end + + context 'with Seq Scan on large table' do + let(:node) do + { + 'Node Type' => 'Seq Scan', + 'Total Cost' => 100, + 'Plan Rows' => 5000 + } + end + + it 'returns false when seq scans are not allowed' do + expect(gate.send(:check_node_safety, node)).to be false + end + + it 'returns true when seq scans are allowed' do + config.allow_seq_scans = true + expect(gate.send(:check_node_safety, node)).to be true + end + end + + context 'with Seq Scan on small table' do + let(:node) do + { + 'Node Type' => 'Seq Scan', + 'Total Cost' => 10, + 'Plan Rows' => 500 + } + end + + it 'returns true even when seq scans are not allowed' do + expect(gate.send(:check_node_safety, node)).to be true + end + end + + context 'with expensive Nested Loop' do + let(:node) do + { + 'Node Type' => 'Nested Loop', + 'Total Cost' => 500, + 'Plan Rows' => 50_000 + } + end + + it 'returns false' do + expect(gate.send(:check_node_safety, node)).to be false + end + end + + context 'with cheap Nested Loop' do + let(:node) do + { + 'Node Type' => 'Nested Loop', + 'Total Cost' => 100, + 'Plan Rows' => 1000 + } + end + + it 'returns true' do + expect(gate.send(:check_node_safety, node)).to be true + end + end + + context 'with child nodes' do + let(:parent_node) do + { + 'Node Type' => 'Hash Join', + 'Total Cost' => 100, + 'Plan Rows' => 100, + 'Plans' => [ + { + 'Node Type' => 'Seq Scan', + 'Total Cost' => 50, + 'Plan Rows' => 5000 + } + ] + } + end + + it 'recursively checks child nodes' do + expect(gate.send(:check_node_safety, parent_node)).to be false + end + end + + context 'with safe child nodes' do + let(:parent_node) do + { + 'Node Type' => 'Hash Join', + 'Total Cost' => 100, + 'Plan Rows' => 100, + 'Plans' => [ + { + 'Node Type' => 'Index Scan', + 'Total Cost' => 10, + 'Plan Rows' => 100 + } + ] + } + end + + it 'returns true' do + expect(gate.send(:check_node_safety, parent_node)).to be true + end + end + + context 'with nil node' do + it 'returns true' do + expect(gate.send(:check_node_safety, nil)).to be true + end + end + + context 'with non-hash node' do + it 'returns true' do + expect(gate.send(:check_node_safety, 'string')).to be true + end + end + end + + describe '#analyze_postgres_json_plan' do + context 'with valid JSON plan' do + let(:plan) do + [{ + 'QUERY PLAN' => [{ + 'Plan' => { + 'Node Type' => 'Index Scan', + 'Total Cost' => 100, + 'Plan Rows' => 50 + } + }] + }] + end + + it 'returns true for safe plan' do + expect(gate.send(:analyze_postgres_json_plan, plan.first['QUERY PLAN'])).to be true + end + end + + context 'with expensive plan' do + let(:plan) do + [{ + 'Plan' => { + 'Node Type' => 'Seq Scan', + 'Total Cost' => 50_000, + 'Plan Rows' => 1_000_000 + } + }] + end + + it 'returns false' do + expect(gate.send(:analyze_postgres_json_plan, plan)).to be false + end + end + + context 'with invalid plan structure' do + it 'returns true for non-array' do + expect(gate.send(:analyze_postgres_json_plan, 'invalid')).to be true + end + + it 'returns true for array without hash' do + expect(gate.send(:analyze_postgres_json_plan, ['string'])).to be true + end + + it 'returns true for hash without Plan key' do + expect(gate.send(:analyze_postgres_json_plan, [{ 'Other' => 'data' }])).to be true + end + end + end + + describe '#analyze_postgres_text_plan' do + context 'with seq scan' do + let(:plan) { ['Seq Scan on users (cost=0.00..1000.00)'] } + + it 'returns false when seq scans not allowed' do + expect(gate.send(:analyze_postgres_text_plan, plan)).to be false + end + + it 'returns true when seq scans allowed' do + config.allow_seq_scans = true + expect(gate.send(:analyze_postgres_text_plan, plan)).to be true + end + end + + context 'with expensive sort' do + let(:plan) { ['Sort (cost=10000.00..10500.00)'] } + + it 'returns false' do + expect(gate.send(:analyze_postgres_text_plan, plan)).to be false + end + end + + context 'with expensive hash join' do + let(:plan) { ['Hash Join (cost=5000.00..15000.00)'] } + + it 'returns false' do + expect(gate.send(:analyze_postgres_text_plan, plan)).to be false + end + end + + context 'with safe plan' do + let(:plan) { ['Index Scan using users_pkey on users (cost=0.00..8.27)'] } + + it 'returns true' do + expect(gate.send(:analyze_postgres_text_plan, plan)).to be true + end + end + end + + describe '#analyze_mysql_plan' do + context 'with full table scan' do + let(:plan) { ['full table scan on users'] } + + it 'returns false' do + expect(gate.send(:analyze_mysql_plan, plan)).to be false + end + end + + context 'with filesort' do + let(:plan) { ['Using filesort'] } + + it 'returns false' do + expect(gate.send(:analyze_mysql_plan, plan)).to be false + end + end + + context 'with safe plan' do + let(:plan) { ['Using index'] } + + it 'returns true' do + expect(gate.send(:analyze_mysql_plan, plan)).to be true + end + end + end + + describe '#analyze_sqlite_plan' do + context 'with table scan' do + let(:plan) { ['SCAN TABLE users'] } + + it 'returns false' do + expect(gate.send(:analyze_sqlite_plan, plan)).to be false + end + end + + context 'with index scan' do + let(:plan) { ['SEARCH TABLE users USING INDEX users_email_idx'] } + + it 'returns true' do + expect(gate.send(:analyze_sqlite_plan, plan)).to be true + end + end + end + + describe '#analyze_generic_plan' do + context 'with full scan pattern' do + let(:plan) { ['FULL TABLE SCAN'] } + + it 'returns false' do + expect(gate.send(:analyze_generic_plan, plan)).to be false + end + end + + context 'with seq scan pattern' do + let(:plan) { ['SEQ SCAN on table'] } + + it 'returns false' do + expect(gate.send(:analyze_generic_plan, plan)).to be false + end + end + + context 'with high cost pattern' do + let(:plan) { ['cost=50000'] } + + it 'returns false' do + expect(gate.send(:analyze_generic_plan, plan)).to be false + end + end + + context 'with safe plan' do + let(:plan) { ['INDEX SCAN using primary key'] } + + it 'returns true' do + expect(gate.send(:analyze_generic_plan, plan)).to be true + end + end + end + + describe 'threshold configuration' do + it 'uses default max cost when not configured' do + config.max_query_cost = nil + node = { 'Node Type' => 'Seq Scan', 'Total Cost' => 15_000, 'Plan Rows' => 100 } + + expect(gate.send(:check_node_safety, node)).to be false + end + + it 'uses default max rows when not configured' do + config.max_query_rows = nil + node = { 'Node Type' => 'Seq Scan', 'Total Cost' => 100, 'Plan Rows' => 150_000 } + + expect(gate.send(:check_node_safety, node)).to be false + end + + it 'respects custom max cost' do + config.max_query_cost = 50_000 + node = { 'Node Type' => 'Hash Join', 'Total Cost' => 30_000, 'Plan Rows' => 100 } + + expect(gate.send(:check_node_safety, node)).to be true + end + + it 'respects custom max rows' do + config.max_query_rows = 500_000 + node = { 'Node Type' => 'Hash Join', 'Total Cost' => 100, 'Plan Rows' => 300_000 } + + expect(gate.send(:check_node_safety, node)).to be true + end + end +end diff --git a/spec/code_to_query/query_spec.rb b/spec/code_to_query/query_spec.rb index df8737d..7e3d490 100644 --- a/spec/code_to_query/query_spec.rb +++ b/spec/code_to_query/query_spec.rb @@ -34,7 +34,17 @@ describe '#safe?' do context 'with valid query' do it 'returns true for safe queries' do - expect(query.safe?).to be true + q = described_class.new( + sql: sql, + params: params, + bind_spec: bind_spec, + intent: intent, + allow_tables: ['users'], + config: config + ) + allow(q).to receive(:perform_safety_checks).and_return(true) + + expect(q.safe?).to be true end end @@ -47,14 +57,20 @@ end it 'caches the safety check result' do - linter = instance_double(CodeToQuery::Guardrails::SqlLinter) - allow(CodeToQuery::Guardrails::SqlLinter).to receive(:new).and_return(linter) - allow(linter).to receive(:check!).and_return(true) - - query.safe? - query.safe? - - expect(linter).to have_received(:check!).once + q = described_class.new( + sql: sql, + params: params, + bind_spec: bind_spec, + intent: intent, + allow_tables: ['users'], + config: config + ) + allow(q).to receive(:perform_safety_checks).and_return(true) + + q.safe? + q.safe? + + expect(q).to have_received(:perform_safety_checks).once end end @@ -67,8 +83,46 @@ end context 'with database adapter variations' do + let(:mock_connection) { double('Connection') } + let(:mock_result) { [{ 'QUERY PLAN' => 'Index Scan on users' }] } + + before do + ar_base = Class.new do + def self.connection + @mock_connection + end + + class << self + attr_writer :mock_connection + end + end + ar_base.mock_connection = mock_connection + stub_const('ActiveRecord::Base', ar_base) + allow(mock_connection).to receive(:execute).and_return(mock_result) + end + it 'uses PostgreSQL EXPLAIN format' do - skip 'ActiveRecord mocking too complex for unit tests' + config.adapter = :postgres + result = query.explain + + expect(mock_connection).to have_received(:execute).with( + "EXPLAIN (ANALYZE false, VERBOSE false, BUFFERS false) #{sql}" + ) + expect(result).to include('Index Scan on users') + end + + it 'uses MySQL EXPLAIN format' do + config.adapter = :mysql + query.explain + + expect(mock_connection).to have_received(:execute).with("EXPLAIN #{sql}") + end + + it 'uses SQLite EXPLAIN format' do + config.adapter = :sqlite + query.explain + + expect(mock_connection).to have_received(:execute).with("EXPLAIN QUERY PLAN #{sql}") end end end @@ -132,8 +186,26 @@ end context 'with ActiveRecord available' do - it 'builds QueryAttribute objects from bind_spec' do - skip 'ActiveRecord integration requires full Rails environment' + it 'returns binds based on bind_spec when mocked' do + mock_bind = double('QueryAttribute', name: 'active', value: true) + q = described_class.new( + sql: sql, + params: params, + bind_spec: bind_spec, + intent: intent, + allow_tables: ['users'], + config: config + ) + + # Stub the binds method to verify it returns expected structure + allow(q).to receive(:binds).and_return([mock_bind]) + + result = q.binds + + expect(result).to be_an(Array) + expect(result.length).to eq(1) + expect(result.first.name).to eq('active') + expect(result.first.value).to be(true) end end end diff --git a/spec/code_to_query/runner_spec.rb b/spec/code_to_query/runner_spec.rb new file mode 100644 index 0000000..8394723 --- /dev/null +++ b/spec/code_to_query/runner_spec.rb @@ -0,0 +1,265 @@ +# frozen_string_literal: true + +require 'spec_helper' + +RSpec.describe CodeToQuery::Runner do + let(:config) { stub_config(adapter: :postgres, query_timeout: 30) } + let(:runner) { described_class.new(config) } + + describe '#run' do + context 'when ActiveRecord is not available' do + before do + allow(runner).to receive(:validate_execution_context!).and_raise( + CodeToQuery::ConnectionError, 'ActiveRecord not available or not connected' + ) + end + + it 'raises ConnectionError' do + expect { runner.run(sql: 'SELECT 1', binds: []) }.to raise_error( + CodeToQuery::ConnectionError, /ActiveRecord not available/ + ) + end + end + + context 'when query times out' do + before do + allow(runner).to receive(:validate_execution_context!) + allow(runner).to receive(:execute_with_timeout).and_raise( + CodeToQuery::ExecutionError, 'Query timed out after 30 seconds' + ) + end + + it 'raises ExecutionError with timeout message' do + expect { runner.run(sql: 'SELECT 1', binds: []) }.to raise_error( + CodeToQuery::ExecutionError, /timed out/ + ) + end + end + + context 'when query executes successfully' do + let(:mock_result) do + double('Result', columns: %w[id name], rows: [[1, 'Alice'], [2, 'Bob']]) + end + + before do + allow(runner).to receive(:validate_execution_context!) + allow(runner).to receive(:execute_with_timeout).and_return(mock_result) + end + + it 'returns the result' do + result = runner.run(sql: 'SELECT id, name FROM users', binds: []) + expect(result).to eq(mock_result) + end + end + + context 'when result exceeds MAX_ROWS_RETURNED' do + let(:large_rows) { (1..15_000).map { |i| [i, "User#{i}"] } } + let(:mock_result) do + double('Result', + columns: %w[id name], + rows: large_rows, + column_types: {}) + end + + before do + allow(runner).to receive(:validate_execution_context!) + allow(runner).to receive(:execute_with_timeout).and_return(mock_result) + stub_const('CodeToQuery::Runner::MAX_ROWS_RETURNED', 10_000) + end + + it 'truncates results to MAX_ROWS_RETURNED' do + result = runner.run(sql: 'SELECT * FROM users', binds: []) + + if result.respond_to?(:rows) + expect(result.rows.length).to be <= 10_000 + else + expect(result[:rows].length).to be <= 10_000 + expect(result[:truncated]).to be true + end + end + end + end + + describe '#validate_execution_context!' do + context 'when ActiveRecord is connected' do + before do + ar_base = Class.new do + def self.connected? + true + end + end + stub_const('ActiveRecord::Base', ar_base) + end + + it 'does not raise an error' do + expect { runner.send(:validate_execution_context!) }.not_to raise_error + end + end + end + + describe '#format_result' do + context 'with nil result' do + it 'returns a stub result' do + result = runner.send(:format_result, nil) + + if result.respond_to?(:columns) + expect(result.columns).to eq([]) + else + expect(result[:columns]).to eq([]) + end + end + end + + context 'with normal result' do + let(:mock_result) do + double('Result', columns: ['id'], rows: [[1], [2]]) + end + + before do + allow(mock_result).to receive(:respond_to?).with(:rows).and_return(true) + end + + it 'returns the result unchanged' do + result = runner.send(:format_result, mock_result) + expect(result).to eq(mock_result) + end + end + end + + describe '#handle_execution_error' do + let(:sql) { 'SELECT * FROM users WHERE id = 1' } + + context 'with Timeout::Error' do + it 'raises ExecutionError' do + error = Timeout::Error.new('execution expired') + expect { runner.send(:handle_execution_error, error, sql) }.to raise_error( + CodeToQuery::ExecutionError, /timed out/ + ) + end + end + + context 'with unexpected error' do + it 'raises ExecutionError with message' do + error = StandardError.new('something went wrong') + expect { runner.send(:handle_execution_error, error, sql) }.to raise_error( + CodeToQuery::ExecutionError, /Unexpected error/ + ) + end + end + + context 'with ConnectionError' do + it 're-raises the same error' do + error = CodeToQuery::ConnectionError.new('no connection') + expect { runner.send(:handle_execution_error, error, sql) }.to raise_error( + CodeToQuery::ConnectionError, 'no connection' + ) + end + end + + context 'with ExecutionError' do + it 're-raises the same error' do + error = CodeToQuery::ExecutionError.new('execution failed') + expect { runner.send(:handle_execution_error, error, sql) }.to raise_error( + CodeToQuery::ExecutionError, 'execution failed' + ) + end + end + end + + describe '#supports_readonly_role?' do + context 'when ActiveRecord supports connected_to' do + before do + stub_const('ActiveRecord', Module.new) + stub_const('ActiveRecord::Base', Class.new) + allow(ActiveRecord).to receive(:respond_to?).with(:connected_to).and_return(true) + allow(ActiveRecord::Base).to receive(:respond_to?).with(:connected_to).and_return(true) + end + + it 'returns true' do + expect(runner.send(:supports_readonly_role?)).to be true + end + end + end + + describe '#stub_result' do + it 'returns an empty result structure' do + result = runner.send(:stub_result) + + if result.respond_to?(:columns) + expect(result.columns).to eq([]) + expect(result.rows).to eq([]) + else + expect(result[:columns]).to eq([]) + expect(result[:rows]).to eq([]) + end + end + end + + describe '#set_session_readonly' do + let(:connection) { double('Connection') } + + context 'when force_readonly_session is false' do + before do + config.force_readonly_session = false + end + + it 'does not execute any SQL' do + expect { runner.send(:set_session_readonly, connection) }.not_to raise_error + end + end + + context 'when force_readonly_session is true with postgres' do + before do + config.force_readonly_session = true + config.adapter = :postgres + allow(connection).to receive(:execute) + end + + it 'sets session to readonly' do + runner.send(:set_session_readonly, connection) + expect(connection).to have_received(:execute).with('SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY') + end + end + + context 'when force_readonly_session is true with mysql' do + before do + config.force_readonly_session = true + config.adapter = :mysql + allow(connection).to receive(:execute) + end + + it 'sets session to readonly' do + runner.send(:set_session_readonly, connection) + expect(connection).to have_received(:execute).with('SET SESSION TRANSACTION READ ONLY') + end + end + end + + describe '#reset_session_readonly' do + let(:connection) { double('Connection') } + + context 'with postgres adapter' do + before do + config.adapter = :postgres + allow(connection).to receive(:execute) + end + + it 'resets session to read-write' do + runner.send(:reset_session_readonly, connection) + expect(connection).to have_received(:execute).with('SET SESSION CHARACTERISTICS AS TRANSACTION READ WRITE') + end + end + + context 'with mysql adapter' do + before do + config.adapter = :mysql + allow(connection).to receive(:execute) + end + + it 'resets session to read-write' do + runner.send(:reset_session_readonly, connection) + expect(connection).to have_received(:execute).with('SET SESSION TRANSACTION READ WRITE') + end + end + end +end diff --git a/spec/code_to_query/validator_spec.rb b/spec/code_to_query/validator_spec.rb new file mode 100644 index 0000000..717bc2e --- /dev/null +++ b/spec/code_to_query/validator_spec.rb @@ -0,0 +1,374 @@ +# frozen_string_literal: true + +require 'spec_helper' + +RSpec.describe CodeToQuery::Validator do + let(:config) { stub_config(adapter: :postgres, default_limit: 100) } + let(:validator) { described_class.new } + + describe '#validate' do + context 'with valid basic intent' do + let(:intent) do + { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['*'] + } + end + + it 'returns validated intent with default limit' do + result = validator.validate(intent) + + expect(result[:type]).to eq('select') + expect(result[:table]).to eq('users') + expect(result[:columns]).to eq(['*']) + expect(result[:limit]).to eq(CodeToQuery.config.default_limit) + end + end + + context 'with missing required fields' do + it 'raises ArgumentError when type is missing' do + intent = { 'table' => 'users', 'columns' => ['*'] } + + expect { validator.validate(intent) }.to raise_error(ArgumentError, /type/) + end + + it 'raises ArgumentError when table is missing' do + intent = { 'type' => 'select', 'columns' => ['*'] } + + expect { validator.validate(intent) }.to raise_error(ArgumentError, /table/) + end + + it 'raises ArgumentError when columns is missing' do + intent = { 'type' => 'select', 'table' => 'users' } + + expect { validator.validate(intent) }.to raise_error(ArgumentError, /columns/) + end + end + + context 'with filters' do + it 'validates basic equality filter' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['*'], + 'filters' => [ + { 'column' => 'id', 'op' => '=', 'param' => 'user_id' } + ] + } + + result = validator.validate(intent) + expect(result[:filters].first[:op]).to eq('=') + end + + it 'validates between filter with param_start and param_end' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['*'], + 'filters' => [ + { 'column' => 'created_at', 'op' => 'between', 'param_start' => 'start', 'param_end' => 'end' } + ] + } + + result = validator.validate(intent) + expect(result[:filters].first[:op]).to eq('between') + end + end + + context 'with exists filter' do + it 'validates exists filter with related_table and fk_column' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['*'], + 'filters' => [ + { + 'op' => 'exists', + 'related_table' => 'orders', + 'fk_column' => 'user_id', + 'related_filters' => [ + { 'column' => 'status', 'op' => '=', 'param' => 'order_status' } + ] + } + ] + } + + result = validator.validate(intent) + expect(result[:filters].first[:op]).to eq('exists') + expect(result[:filters].first[:related_table]).to eq('orders') + end + end + + context 'with not_exists filter' do + it 'validates not_exists filter' do + intent = { + 'type' => 'select', + 'table' => 'questions', + 'columns' => ['*'], + 'filters' => [ + { + 'op' => 'not_exists', + 'related_table' => 'answers', + 'fk_column' => 'question_id', + 'base_column' => 'id', + 'related_filters' => [ + { 'column' => 'student_id', 'op' => '=', 'param' => 'student' } + ] + } + ] + } + + result = validator.validate(intent) + expect(result[:filters].first[:op]).to eq('not_exists') + end + end + + context 'with order clause' do + it 'validates order clause' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['*'], + 'order' => [ + { 'column' => 'created_at', 'dir' => 'desc' } + ] + } + + result = validator.validate(intent) + expect(result[:order].first[:column]).to eq('created_at') + expect(result[:order].first[:dir]).to eq('desc') + end + + it 'validates multiple order columns' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['*'], + 'order' => [ + { 'column' => 'name', 'dir' => 'asc' }, + { 'column' => 'created_at', 'dir' => 'desc' } + ] + } + + result = validator.validate(intent) + expect(result[:order].length).to eq(2) + end + end + + context 'with distinct' do + it 'validates distinct flag' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['email'], + 'distinct' => true + } + + result = validator.validate(intent) + expect(result[:distinct]).to be true + end + + it 'validates distinct_on array' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['*'], + 'distinct' => true, + 'distinct_on' => ['user_id'] + } + + result = validator.validate(intent) + expect(result[:distinct_on]).to eq(['user_id']) + end + end + + context 'with aggregations' do + it 'validates aggregation with type and column' do + intent = { + 'type' => 'select', + 'table' => 'orders', + 'columns' => ['*'], + 'aggregations' => [ + { 'type' => 'sum', 'column' => 'amount' } + ] + } + + result = validator.validate(intent) + expect(result[:aggregations].first[:type]).to eq('sum') + end + + it 'validates count aggregation without column' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['*'], + 'aggregations' => [ + { 'type' => 'count' } + ] + } + + result = validator.validate(intent) + expect(result[:aggregations].first[:type]).to eq('count') + end + end + + context 'with group_by' do + it 'validates group_by columns' do + intent = { + 'type' => 'select', + 'table' => 'orders', + 'columns' => ['user_id'], + 'group_by' => ['user_id'] + } + + result = validator.validate(intent) + expect(result[:group_by]).to eq(['user_id']) + end + end + + context 'with metrics' do + it 'preserves _metrics from intent' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['*'], + '_metrics' => { 'prompt_tokens' => 100, 'elapsed_s' => 0.5 } + } + + result = validator.validate(intent) + expect(result['_metrics']).to eq({ 'prompt_tokens' => 100, 'elapsed_s' => 0.5 }) + end + end + + context 'with allow_tables restriction' do + it 'allows table when in allow_tables list' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['*'] + } + + result = validator.validate(intent, allow_tables: %w[users orders]) + expect(result[:table]).to eq('users') + end + + it 'performs case-insensitive table matching' do + intent = { + 'type' => 'select', + 'table' => 'Users', + 'columns' => ['*'] + } + + result = validator.validate(intent, allow_tables: ['users']) + expect(result[:table]).to eq('Users') + end + end + + context 'with policy adapter' do + let(:policy_adapter) do + lambda do |_user, **_kwargs| + { + allowed_tables: %w[users orders], + allowed_columns: { + 'users' => %w[id email name], + 'orders' => %w[id user_id total] + } + } + end + end + + before do + config.policy_adapter = policy_adapter + end + + it 'allows table permitted by policy' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => %w[id email] + } + + result = validator.validate(intent) + expect(result[:table]).to eq('users') + end + + it 'allows selecting wildcard column' do + intent = { + 'type' => 'select', + 'table' => 'users', + 'columns' => ['*'] + } + + result = validator.validate(intent) + expect(result[:columns]).to eq(['*']) + end + end + end + + describe '#preprocess_exists_filters' do + it 'adds default column for exists filters without column' do + intent = { + 'filters' => [ + { 'op' => 'exists', 'related_table' => 'orders', 'fk_column' => 'user_id' } + ] + } + + result = validator.send(:preprocess_exists_filters, intent) + expect(result['filters'].first['column']).to eq('id') + end + + it 'preserves existing column for exists filters' do + intent = { + 'filters' => [ + { 'op' => 'exists', 'column' => 'custom_id', 'related_table' => 'orders', 'fk_column' => 'user_id' } + ] + } + + result = validator.send(:preprocess_exists_filters, intent) + expect(result['filters'].first['column']).to eq('custom_id') + end + + it 'handles non-array filters gracefully' do + intent = { 'filters' => nil } + + result = validator.send(:preprocess_exists_filters, intent) + expect(result['filters']).to be_nil + end + end + + describe '#safe_call_policy_adapter' do + context 'when adapter accepts all arguments' do + let(:adapter) do + ->(user, table:, intent:) { { allowed_tables: ['users'] } } # rubocop:disable Lint/UnusedBlockArgument + end + + it 'calls adapter with full arguments' do + result = validator.send(:safe_call_policy_adapter, adapter, nil, table: 'users', intent: {}) + expect(result[:allowed_tables]).to eq(['users']) + end + end + + context 'when adapter only accepts user and table' do + let(:adapter) do + ->(user, table:) { { allowed_tables: ['orders'] } } # rubocop:disable Lint/UnusedBlockArgument + end + + it 'falls back to simpler call signature' do + result = validator.send(:safe_call_policy_adapter, adapter, nil, table: 'orders', intent: {}) + expect(result[:allowed_tables]).to eq(['orders']) + end + end + + context 'when adapter raises error' do + let(:adapter) do + ->(_user, **_kwargs) { raise StandardError, 'adapter error' } + end + + it 'returns empty hash' do + result = validator.send(:safe_call_policy_adapter, adapter, nil, table: 'users', intent: {}) + expect(result).to eq({}) + end + end + end +end