diff --git a/lib/active_record/connection_adapters/pinot_adapter.rb b/lib/active_record/connection_adapters/pinot_adapter.rb index 9762bdf..a975e98 100644 --- a/lib/active_record/connection_adapters/pinot_adapter.rb +++ b/lib/active_record/connection_adapters/pinot_adapter.rb @@ -13,6 +13,20 @@ def pinot_connection(config) module ConnectionAdapters class PinotAdapter < AbstractAdapter + NULL_REGEX = /^null$/i + SINGLE_QUOTED_STRING_REGEX = /^'([^|]*)'$/m + DOUBLE_QUOTED_STRING_REGEX = /^"([^|]*)"$/m + NUMERIC_REGEX = /\A-?\d+(\.\d*)?\z/ + BINARY_HEX_REGEX = /x'(.*)'/ + # Matches SQL functions or expressions (e.g., NOW(), CURRENT_DATE, or string concatenation) + SQL_FUNCTION_OR_EXPRESSION_REGEX = %r{ + \w+\(.*\) | # SQL functions like NOW(), uuid_generate_v4(), etc. + CURRENT_TIME | # special SQL keyword + CURRENT_DATE | # special SQL keyword + CURRENT_TIMESTAMP | # special SQL keyword + \|\| # SQL string concatenation operator + }x + TYPES = { "INT" => Type::Integer.new, "TIMESTAMP" => Type::DateTime.new, @@ -21,6 +35,7 @@ class PinotAdapter < AbstractAdapter "STRING" => Type::String.new, "JSON" => ActiveRecord::Type::Json.new } + def initialize(config = {}) @pinot_host = config.fetch(:host) @pinot_port = config.fetch(:port) @@ -84,24 +99,16 @@ def new_column_from_field(table_name, field, definitions = nil) def extract_value_from_default(default) case default - when /^null$/i - nil - # Quoted types - when /^'([^|]*)'$/m - $1.gsub("''", "'") - # Quoted types - when /^"([^|]*)"$/m - $1.gsub('""', '"') - # Numeric types - when /\A-?\d+(\.\d*)?\z/ - $& - # Binary columns - when /x'(.*)'/ - [$1].pack("H*") - else - # Anything else is blank or some function - # and we can't know the value of that, so return nil. + when NULL_REGEX nil + when SINGLE_QUOTED_STRING_REGEX + ::Regexp.last_match(1).gsub("''", "'") + when DOUBLE_QUOTED_STRING_REGEX + ::Regexp.last_match(1).gsub('""', '"') + when NUMERIC_REGEX + ::Regexp.last_match(0) + when BINARY_HEX_REGEX + [::Regexp.last_match(1)].pack('H*') end end @@ -110,7 +117,7 @@ def extract_default_function(default_value, default) end def has_default_function?(default_value, default) - !default_value && %r{\w+\(.*\)|CURRENT_TIME|CURRENT_DATE|CURRENT_TIMESTAMP|\|\|}.match?(default) + !default_value && SQL_FUNCTION_OR_EXPRESSION_REGEX.match?(default) end INTEGER_REGEX = /integer/i