@@ -75,13 +75,15 @@ fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, Sy
7575 . and_then ( |x| x. segment ( ) )
7676 . and_then ( |x| x. name_ref ( ) )
7777 {
78- return name_from_name_ref ( name_ref, true ) . map ( |( column, node) | {
79- let column = match column {
80- ColumnName :: Column ( c) => ColumnName :: new ( c, unknown_column) ,
81- _ => column,
82- } ;
83- ( column, node)
84- } ) ;
78+ return name_from_name_ref ( name_ref, true , path_type. arg_list ( ) . as_ref ( ) ) . map (
79+ |( column, node) | {
80+ let column = match column {
81+ ColumnName :: Column ( c) => ColumnName :: new ( c, unknown_column) ,
82+ _ => column,
83+ } ;
84+ ( column, node)
85+ } ,
86+ ) ;
8587 }
8688 }
8789 ast:: Type :: BitType ( bit_type) => {
@@ -158,7 +160,11 @@ fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, Sy
158160 None
159161}
160162
161- fn name_from_name_ref ( name_ref : ast:: NameRef , in_type : bool ) -> Option < ( ColumnName , SyntaxNode ) > {
163+ fn name_from_name_ref (
164+ name_ref : ast:: NameRef ,
165+ in_type : bool ,
166+ arg_list : Option < & ast:: ArgList > ,
167+ ) -> Option < ( ColumnName , SyntaxNode ) > {
162168 if in_type {
163169 for node in name_ref. syntax ( ) . children_with_tokens ( ) {
164170 match node. kind ( ) {
@@ -174,12 +180,32 @@ fn name_from_name_ref(name_ref: ast::NameRef, in_type: bool) -> Option<(ColumnNa
174180 name_ref. syntax ( ) . clone ( ) ,
175181 ) ) ;
176182 }
177- SyntaxKind :: DECIMAL_KW => {
183+ SyntaxKind :: DEC_KW | SyntaxKind :: DECIMAL_KW => {
178184 return Some ( (
179185 ColumnName :: Column ( "numeric" . to_owned ( ) ) ,
180186 name_ref. syntax ( ) . clone ( ) ,
181187 ) ) ;
182188 }
189+ SyntaxKind :: FLOAT_KW => {
190+ let precision = arg_list. and_then ( |arg| {
191+ arg. args_ ( ) . find_map ( |arg| {
192+ if let ast:: Expr :: Literal ( lit) = arg. expr ( ) ? {
193+ lit. syntax ( ) . text ( ) . to_string ( ) . parse :: < u32 > ( ) . ok ( )
194+ } else {
195+ None
196+ }
197+ } )
198+ } ) ;
199+ let name = if matches ! ( precision, Some ( p) if p <= 24 ) {
200+ "float4"
201+ } else {
202+ "float8"
203+ } ;
204+ return Some ( (
205+ ColumnName :: Column ( name. to_owned ( ) ) ,
206+ name_ref. syntax ( ) . clone ( ) ,
207+ ) ) ;
208+ }
183209 SyntaxKind :: INT_KW | SyntaxKind :: INTEGER_KW => {
184210 return Some ( (
185211 ColumnName :: Column ( "int4" . to_owned ( ) ) ,
@@ -409,11 +435,11 @@ fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxN
409435 | ast:: Expr :: SliceExpr ( _) => unreachable ! ( "not possible in the grammar" ) ,
410436 ast:: Expr :: FieldExpr ( field_expr) => {
411437 if let Some ( name_ref) = field_expr. field ( ) {
412- return name_from_name_ref ( name_ref, in_type) ;
438+ return name_from_name_ref ( name_ref, in_type, None ) ;
413439 }
414440 }
415441 ast:: Expr :: NameRef ( name_ref) => {
416- return name_from_name_ref ( name_ref, in_type) ;
442+ return name_from_name_ref ( name_ref, in_type, None ) ;
417443 }
418444 }
419445 }
@@ -446,7 +472,7 @@ fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxN
446472 }
447473 ast:: Expr :: FieldExpr ( field_expr) => {
448474 if let Some ( name_ref) = field_expr. field ( ) {
449- return name_from_name_ref ( name_ref, in_type) ;
475+ return name_from_name_ref ( name_ref, in_type, None ) ;
450476 }
451477 }
452478 ast:: Expr :: IndexExpr ( index_expr) => {
@@ -472,7 +498,7 @@ fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxN
472498 _ => return Some ( ( ColumnName :: UnknownColumn ( None ) , node) ) ,
473499 } ,
474500 ast:: Expr :: NameRef ( name_ref) => {
475- return name_from_name_ref ( name_ref, in_type) ;
501+ return name_from_name_ref ( name_ref, in_type, None ) ;
476502 }
477503 ast:: Expr :: ParenExpr ( paren_expr) => {
478504 if let Some ( expr) = paren_expr. expr ( ) {
@@ -614,20 +640,32 @@ fn examples() {
614640
615641 // bit types
616642 assert_snapshot ! ( name( "cast('1010' as bit varying(10))" ) , @"varbit" ) ;
643+ assert_snapshot ! ( name( "cast('1010' as bit varying)" ) , @"varbit" ) ;
644+ assert_snapshot ! ( name( "cast('1010' as bit)" ) , @"bit" ) ;
645+
646+ // decimal
647+ assert_snapshot ! ( name( "cast('1010' as dec)" ) , @"numeric" ) ;
648+ assert_snapshot ! ( name( "cast('1010' as dec(10))" ) , @"numeric" ) ;
649+ assert_snapshot ! ( name( "cast('1010' as decimal)" ) , @"numeric" ) ;
650+ assert_snapshot ! ( name( "cast('1010' as decimal(10))" ) , @"numeric" ) ;
617651
618652 // char types
619653 assert_snapshot ! ( name( "cast('hello' as character varying(10))" ) , @"varchar" ) ;
620654 assert_snapshot ! ( name( "cast('hello' as char varying(5))" ) , @"varchar" ) ;
655+ assert_snapshot ! ( name( "cast('hello' as nchar varying(10))" ) , @"varchar" ) ;
621656 assert_snapshot ! ( name( "cast('hello' as char(5))" ) , @"bpchar" ) ;
622657 assert_snapshot ! ( name( "cast('hello' as character)" ) , @"bpchar" ) ;
623658 assert_snapshot ! ( name( "cast('hello' as bpchar)" ) , @"bpchar" ) ;
659+ assert_snapshot ! ( name( "cast('hello' as nchar(10))" ) , @"bpchar" ) ;
624660
625661 assert_snapshot ! ( name( r#"cast('hello' as "char")"# ) , @"char" ) ;
626662
627663 // double types
628664 assert_snapshot ! ( name( "cast(1.5 as double precision)" ) , @"float8" ) ;
629665 // real
630666 assert_snapshot ! ( name( "cast(1.5 as real)" ) , @"float4" ) ;
667+ assert_snapshot ! ( name( "cast(1.5 as float(8))" ) , @"float4" ) ;
668+ assert_snapshot ! ( name( "cast(2.5 as float(25))" ) , @"float8" ) ;
631669
632670 // interval types
633671 assert_snapshot ! ( name( "cast('1 hour' as interval hour to minute)" ) , @"interval" ) ;
0 commit comments