@@ -555,3 +555,310 @@ func TestParser_JoinTreeLogic(t *testing.T) {
555555 }
556556 }
557557}
558+
559+ // TestParser_MultiColumnUSING tests multi-column USING clause support (Issue #70)
560+ func TestParser_MultiColumnUSING (t * testing.T ) {
561+ tests := []struct {
562+ name string
563+ sql string
564+ expectedColumns []string
565+ wantErr bool
566+ }{
567+ {
568+ name : "Single column USING (backward compatibility)" ,
569+ sql : "SELECT * FROM users JOIN orders USING (id)" ,
570+ expectedColumns : []string {"id" },
571+ wantErr : false ,
572+ },
573+ {
574+ name : "Two column USING" ,
575+ sql : "SELECT * FROM users JOIN orders USING (id, name)" ,
576+ expectedColumns : []string {"id" , "name" },
577+ wantErr : false ,
578+ },
579+ {
580+ name : "Three column USING" ,
581+ sql : "SELECT * FROM users JOIN orders USING (id, name, category)" ,
582+ expectedColumns : []string {"id" , "name" , "category" },
583+ wantErr : false ,
584+ },
585+ {
586+ name : "Multiple columns with LEFT JOIN" ,
587+ sql : "SELECT * FROM users LEFT JOIN orders USING (user_id, account_id)" ,
588+ expectedColumns : []string {"user_id" , "account_id" },
589+ wantErr : false ,
590+ },
591+ {
592+ name : "Multiple columns with INNER JOIN" ,
593+ sql : "SELECT * FROM products INNER JOIN categories USING (category_id, subcategory_id)" ,
594+ expectedColumns : []string {"category_id" , "subcategory_id" },
595+ wantErr : false ,
596+ },
597+ {
598+ name : "Four columns USING" ,
599+ sql : "SELECT * FROM table1 JOIN table2 USING (col1, col2, col3, col4)" ,
600+ expectedColumns : []string {"col1" , "col2" , "col3" , "col4" },
601+ wantErr : false ,
602+ },
603+ }
604+
605+ for _ , tt := range tests {
606+ t .Run (tt .name , func (t * testing.T ) {
607+ // Get tokenizer from pool
608+ tkz := tokenizer .GetTokenizer ()
609+ defer tokenizer .PutTokenizer (tkz )
610+
611+ // Tokenize SQL
612+ tokens , err := tkz .Tokenize ([]byte (tt .sql ))
613+ if err != nil {
614+ t .Fatalf ("Failed to tokenize: %v" , err )
615+ }
616+
617+ // Convert tokens for parser
618+ convertedTokens := convertTokens (tokens )
619+
620+ // Parse tokens
621+ parser := & Parser {}
622+ astObj , err := parser .Parse (convertedTokens )
623+ if (err != nil ) != tt .wantErr {
624+ t .Errorf ("Parse() error = %v, wantErr %v" , err , tt .wantErr )
625+ return
626+ }
627+
628+ if ! tt .wantErr && astObj != nil {
629+ defer ast .ReleaseAST (astObj )
630+
631+ // Verify we have a SELECT statement
632+ if len (astObj .Statements ) == 0 {
633+ t .Fatal ("No statements parsed" )
634+ }
635+
636+ selectStmt , ok := astObj .Statements [0 ].(* ast.SelectStatement )
637+ if ! ok {
638+ t .Fatal ("Expected SELECT statement" )
639+ }
640+
641+ // Verify we have a JOIN
642+ if len (selectStmt .Joins ) == 0 {
643+ t .Fatal ("Expected at least one JOIN" )
644+ }
645+
646+ join := selectStmt .Joins [0 ]
647+ if join .Condition == nil {
648+ t .Fatal ("Expected JOIN condition (USING clause)" )
649+ }
650+
651+ // Verify the columns
652+ if len (tt .expectedColumns ) == 1 {
653+ // Single column - should be stored as Identifier
654+ ident , ok := join .Condition .(* ast.Identifier )
655+ if ! ok {
656+ t .Fatalf ("Expected Identifier for single column USING, got %T" , join .Condition )
657+ }
658+ if ident .Name != tt .expectedColumns [0 ] {
659+ t .Errorf ("Expected column %s, got %s" , tt .expectedColumns [0 ], ident .Name )
660+ }
661+ } else {
662+ // Multiple columns - should be stored as ListExpression
663+ listExpr , ok := join .Condition .(* ast.ListExpression )
664+ if ! ok {
665+ t .Fatalf ("Expected ListExpression for multi-column USING, got %T" , join .Condition )
666+ }
667+
668+ if len (listExpr .Values ) != len (tt .expectedColumns ) {
669+ t .Fatalf ("Expected %d columns, got %d" , len (tt .expectedColumns ), len (listExpr .Values ))
670+ }
671+
672+ // Verify each column
673+ for i , expectedCol := range tt .expectedColumns {
674+ ident , ok := listExpr .Values [i ].(* ast.Identifier )
675+ if ! ok {
676+ t .Fatalf ("Column %d: expected Identifier, got %T" , i , listExpr .Values [i ])
677+ }
678+ if ident .Name != expectedCol {
679+ t .Errorf ("Column %d: expected %s, got %s" , i , expectedCol , ident .Name )
680+ }
681+ }
682+ }
683+ }
684+ })
685+ }
686+ }
687+
688+ // TestParser_MultiColumnUSINGEdgeCases tests edge cases for multi-column USING
689+ func TestParser_MultiColumnUSINGEdgeCases (t * testing.T ) {
690+ tests := []struct {
691+ name string
692+ sql string
693+ expectedError string
694+ wantErr bool
695+ }{
696+ {
697+ name : "Empty USING clause" ,
698+ sql : "SELECT * FROM users JOIN orders USING ()" ,
699+ expectedError : "expected column name in USING" ,
700+ wantErr : true ,
701+ },
702+ {
703+ name : "USING with trailing comma" ,
704+ sql : "SELECT * FROM users JOIN orders USING (id, name,)" ,
705+ expectedError : "expected column name in USING" ,
706+ wantErr : true ,
707+ },
708+ {
709+ name : "USING without closing parenthesis" ,
710+ sql : "SELECT * FROM users JOIN orders USING (id, name" ,
711+ expectedError : "expected ) after USING column list" ,
712+ wantErr : true ,
713+ },
714+ {
715+ name : "USING without opening parenthesis" ,
716+ sql : "SELECT * FROM users JOIN orders USING id, name)" ,
717+ expectedError : "expected ( after USING" ,
718+ wantErr : true ,
719+ },
720+ {
721+ name : "USING with non-identifier" ,
722+ sql : "SELECT * FROM users JOIN orders USING (id, 123)" ,
723+ expectedError : "expected column name in USING" ,
724+ wantErr : true ,
725+ },
726+ {
727+ name : "Multiple commas in USING" ,
728+ sql : "SELECT * FROM users JOIN orders USING (id,, name)" ,
729+ expectedError : "expected column name in USING" ,
730+ wantErr : true ,
731+ },
732+ }
733+
734+ for _ , tt := range tests {
735+ t .Run (tt .name , func (t * testing.T ) {
736+ // Get tokenizer from pool
737+ tkz := tokenizer .GetTokenizer ()
738+ defer tokenizer .PutTokenizer (tkz )
739+
740+ // Tokenize SQL
741+ tokens , err := tkz .Tokenize ([]byte (tt .sql ))
742+ if err != nil {
743+ // Some tests might fail at tokenization level
744+ if tt .wantErr {
745+ return // Expected failure
746+ }
747+ t .Fatalf ("Failed to tokenize: %v" , err )
748+ }
749+
750+ // Convert tokens for parser
751+ convertedTokens := convertTokens (tokens )
752+
753+ // Parse tokens
754+ parser := & Parser {}
755+ astObj , err := parser .Parse (convertedTokens )
756+
757+ if tt .wantErr {
758+ if err == nil {
759+ if astObj != nil {
760+ defer ast .ReleaseAST (astObj )
761+ }
762+ t .Errorf ("Expected error containing '%s', but got no error" , tt .expectedError )
763+ } else if ! containsError (err .Error (), tt .expectedError ) {
764+ t .Errorf ("Expected error containing '%s', got '%s'" , tt .expectedError , err .Error ())
765+ }
766+ } else {
767+ if err != nil {
768+ t .Errorf ("Unexpected error: %v" , err )
769+ }
770+ if astObj != nil {
771+ defer ast .ReleaseAST (astObj )
772+ }
773+ }
774+ })
775+ }
776+ }
777+
778+ // TestParser_MultiColumnUSINGWithComplexQueries tests multi-column USING in complex scenarios
779+ func TestParser_MultiColumnUSINGWithComplexQueries (t * testing.T ) {
780+ tests := []struct {
781+ name string
782+ sql string
783+ expectJoins int
784+ wantErr bool
785+ }{
786+ {
787+ name : "Multiple JOINs with multi-column USING" ,
788+ sql : `SELECT * FROM users
789+ JOIN orders USING (user_id, account_id)
790+ JOIN products USING (product_id, category_id)` ,
791+ expectJoins : 2 ,
792+ wantErr : false ,
793+ },
794+ {
795+ name : "Mixed ON and USING clauses" ,
796+ sql : `SELECT * FROM users u
797+ JOIN orders o USING (user_id, tenant_id)
798+ LEFT JOIN products p ON o.product_id = p.id` ,
799+ expectJoins : 2 ,
800+ wantErr : false ,
801+ },
802+ {
803+ name : "Multi-column USING with WHERE clause" ,
804+ sql : `SELECT * FROM users
805+ JOIN orders USING (user_id, account_id)
806+ WHERE users.active = true` ,
807+ expectJoins : 1 ,
808+ wantErr : false ,
809+ },
810+ {
811+ name : "Multi-column USING with ORDER BY and LIMIT" ,
812+ sql : `SELECT * FROM users
813+ JOIN orders USING (user_id, tenant_id)
814+ ORDER BY users.created_at DESC
815+ LIMIT 100` ,
816+ expectJoins : 1 ,
817+ wantErr : false ,
818+ },
819+ }
820+
821+ for _ , tt := range tests {
822+ t .Run (tt .name , func (t * testing.T ) {
823+ // Get tokenizer from pool
824+ tkz := tokenizer .GetTokenizer ()
825+ defer tokenizer .PutTokenizer (tkz )
826+
827+ // Tokenize SQL
828+ tokens , err := tkz .Tokenize ([]byte (tt .sql ))
829+ if err != nil {
830+ t .Fatalf ("Failed to tokenize: %v" , err )
831+ }
832+
833+ // Convert tokens for parser
834+ convertedTokens := convertTokens (tokens )
835+
836+ // Parse tokens
837+ parser := & Parser {}
838+ astObj , err := parser .Parse (convertedTokens )
839+ if (err != nil ) != tt .wantErr {
840+ t .Errorf ("Parse() error = %v, wantErr %v" , err , tt .wantErr )
841+ return
842+ }
843+
844+ if ! tt .wantErr && astObj != nil {
845+ defer ast .ReleaseAST (astObj )
846+
847+ // Verify we have a SELECT statement
848+ if len (astObj .Statements ) == 0 {
849+ t .Fatal ("No statements parsed" )
850+ }
851+
852+ selectStmt , ok := astObj .Statements [0 ].(* ast.SelectStatement )
853+ if ! ok {
854+ t .Fatal ("Expected SELECT statement" )
855+ }
856+
857+ // Verify JOIN count
858+ if len (selectStmt .Joins ) != tt .expectJoins {
859+ t .Errorf ("Expected %d JOINs, got %d" , tt .expectJoins , len (selectStmt .Joins ))
860+ }
861+ }
862+ })
863+ }
864+ }
0 commit comments