Skip to content

Commit c50e06c

Browse files
authored
Merge pull request #410 from mindsdb/fix-cast-decimal
Support CAST(a AS decimal(x, y))
2 parents fd1ae98 + ec6751b commit c50e06c

6 files changed

Lines changed: 35 additions & 11 deletions

File tree

mindsdb_sql/parser/ast/select/type_cast.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,20 @@
33

44

55
class TypeCast(ASTNode):
6-
def __init__(self, type_name, arg, length=None, *args, **kwargs):
6+
def __init__(self, type_name, arg, precision=None, *args, **kwargs):
77
super().__init__(*args, **kwargs)
88

99
self.type_name = type_name
1010
self.arg = arg
11-
self.length = length
11+
self.precision = precision
1212

1313
def to_tree(self, *args, level=0, **kwargs):
14-
out_str = indent(level) + f'TypeCast(type_name={repr(self.type_name)}, length={self.length}, arg=\n{indent(level+1)}{self.arg.to_tree()})'
14+
out_str = indent(level) + f'TypeCast(type_name={repr(self.type_name)}, precision={self.precision}, arg=\n{indent(level+1)}{self.arg.to_tree()})'
1515
return out_str
1616

1717
def get_string(self, *args, **kwargs):
1818
type_name = self.type_name
19-
if self.length is not None:
20-
type_name += f'({self.length})'
19+
if self.precision is not None:
20+
precision = map(str, self.precision)
21+
type_name += f'({",".join(precision)})'
2122
return f'CAST({str(self.arg)} AS {type_name})'

mindsdb_sql/parser/dialects/mindsdb/parser.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1466,8 +1466,13 @@ def expr_list_or_nothing(self, p):
14661466
pass
14671467

14681468
@_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN')
1469+
@_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN')
14691470
def expr(self, p):
1470-
return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer)
1471+
if hasattr(p, 'integer'):
1472+
precision=[p.integer]
1473+
else:
1474+
precision=[p.integer0, p.integer1]
1475+
return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision)
14711476

14721477
@_('CAST LPAREN expr AS id RPAREN')
14731478
def expr(self, p):

mindsdb_sql/parser/dialects/mysql/parser.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,8 +821,13 @@ def expr_list_or_nothing(self, p):
821821
pass
822822

823823
@_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN')
824+
@_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN')
824825
def expr(self, p):
825-
return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer)
826+
if hasattr(p, 'integer'):
827+
precision=[p.integer]
828+
else:
829+
precision=[p.integer0, p.integer1]
830+
return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision)
826831

827832
@_('CAST LPAREN expr AS id RPAREN')
828833
def expr(self, p):

mindsdb_sql/parser/parser.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,8 +581,13 @@ def expr_list_or_nothing(self, p):
581581
pass
582582

583583
@_('CAST LPAREN expr AS id LPAREN integer RPAREN RPAREN')
584+
@_('CAST LPAREN expr AS id LPAREN integer COMMA integer RPAREN RPAREN')
584585
def expr(self, p):
585-
return TypeCast(arg=p.expr, type_name=str(p.id), length=p.integer)
586+
if hasattr(p, 'integer'):
587+
precision=[p.integer]
588+
else:
589+
precision=[p.integer0, p.integer1]
590+
return TypeCast(arg=p.expr, type_name=str(p.id), precision=precision)
586591

587592
@_('CAST LPAREN expr AS id RPAREN')
588593
def expr(self, p):

mindsdb_sql/render/sqlalchemy_render.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ def to_expression(self, t):
254254
elif isinstance(t, ast.TypeCast):
255255
arg = self.to_expression(t.arg)
256256
type = self.get_type(t.type_name)
257-
if t.length is not None:
258-
type = type(t.length)
257+
if t.precision is not None:
258+
type = type(*t.precision)
259259
col = sa.cast(arg, type)
260260

261261
if t.alias:

tests/test_parser/test_base_sql/test_select_structure.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,15 @@ def test_type_cast(self, dialect):
633633
sql = f"""SELECT CAST(a AS CHAR(10))"""
634634
ast = parse_sql(sql, dialect=dialect)
635635
expected_ast = Select(targets=[
636-
TypeCast(type_name='CHAR', arg=Identifier('a'), length=10)
636+
TypeCast(type_name='CHAR', arg=Identifier('a'), precision=[10])
637+
])
638+
assert ast.to_tree() == expected_ast.to_tree()
639+
assert str(ast) == str(expected_ast)
640+
641+
sql = f"""SELECT CAST(a AS DECIMAL(10, 1))"""
642+
ast = parse_sql(sql, dialect=dialect)
643+
expected_ast = Select(targets=[
644+
TypeCast(type_name='DECIMAL', arg=Identifier('a'), precision=[10, 1])
637645
])
638646
assert ast.to_tree() == expected_ast.to_tree()
639647
assert str(ast) == str(expected_ast)

0 commit comments

Comments
 (0)