-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathparse_tree.py
More file actions
109 lines (90 loc) · 3.64 KB
/
parse_tree.py
File metadata and controls
109 lines (90 loc) · 3.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import logging
import sys
from antlr4 import *
from antlr4.error.ErrorListener import ErrorListener
from preprocessor.antlr_parser.pg_parser.PostgreSQLParser import PostgreSQLParser
from preprocessor.antlr_parser.pg_parser.PostgreSQLLexer import PostgreSQLLexer
from preprocessor.antlr_parser.mysql_parser.MySqlParser import MySqlParser
from preprocessor.antlr_parser.mysql_parser.MySqlLexer import MySqlLexer
from preprocessor.antlr_parser.oracle_parser.PlSqlParser import PlSqlParser
from preprocessor.antlr_parser.oracle_parser.PlSqlLexer import PlSqlLexer
from utils.constants import DIALECT_LIST
class CustomErrorListener(ErrorListener):
    """ANTLR error listener that converts every reported syntax error into a ``SelfParseError``.

    Attach it to both the lexer and the parser; the ``parse_*_tree`` helpers
    catch the raised ``SelfParseError`` and turn it into a return value.
    """

    def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
        # Raising (rather than recording) propagates out of the lexer/parser
        # call that reported the problem, so parsing stops at the first error.
        raise SelfParseError(line=line, column=column, msg=msg)
def parse_tree(src_sql: str, dialect: str) -> (str, int, int, str):
    """Parse *src_sql* with the ANTLR grammar selected by *dialect*.

    Args:
        src_sql: SQL text to parse.
        dialect: one of ``'pg'``, ``'mysql'`` or ``'oracle'``.

    Returns:
        Whatever the matching ``parse_*_tree`` helper returns:
        ``(tree, None, None, None)`` on success, or
        ``(None, line, column, msg)`` when a syntax error is reported.

    Raises:
        ValueError: if *dialect* is not one of the supported names.
    """
    dispatch = (
        ('pg', parse_pg_tree),
        ('mysql', parse_mysql_tree),
        ('oracle', parse_oracle_tree),
    )
    for name, handler in dispatch:
        if dialect == name:
            return handler(src_sql)
    raise ValueError("use one of" + str(DIALECT_LIST) + " as argument")
def parse_pg_tree(src_sql: str) -> (str, int, int, str):
    """Parse PostgreSQL text, starting from the grammar's ``root`` rule.

    Args:
        src_sql: SQL text to parse.

    Returns:
        ``(tree, None, None, None)`` on success, where ``tree`` is the value
        returned by ``PostgreSQLParser.root()``; or ``(None, line, column, msg)``
        when the lexer or parser reports a syntax error.

    Raises:
        Exception: any other error is logged with its traceback and re-raised
            unchanged.
    """
    try:
        lexer = PostgreSQLLexer(InputStream(src_sql))
        # Replace ANTLR's default listeners (which print to stderr) with one
        # that raises SelfParseError, so each error is reported exactly once.
        lexer.removeErrorListeners()
        lexer.addErrorListener(CustomErrorListener())
        parser = PostgreSQLParser(CommonTokenStream(lexer))
        parser.removeErrorListeners()
        parser.addErrorListener(CustomErrorListener())
        tree = parser.root()
        return tree, None, None, None
    except SelfParseError as e:
        return None, e.line, e.column, e.msg
    except Exception as e:
        # logging.error() accepts no ``file`` keyword; passing one raised a
        # TypeError that hid the real error. logging.exception also records
        # the traceback, and a bare ``raise`` re-raises the original error.
        logging.exception("An error occurred: %s", e)
        raise
def parse_mysql_tree(src_sql: str):
    """Parse MySQL text, starting from the grammar's ``root`` rule.

    Args:
        src_sql: SQL text to parse.

    Returns:
        ``(tree, None, None, None)`` on success, where ``tree`` is the value
        returned by ``MySqlParser.root()``; or ``(None, line, column, msg)``
        when the lexer or parser reports a syntax error.

    Raises:
        Exception: any other error is logged with its traceback and re-raised
            unchanged.
    """
    try:
        lexer = MySqlLexer(InputStream(src_sql))
        # Replace ANTLR's default listeners (which print to stderr) with one
        # that raises SelfParseError, so each error is reported exactly once.
        lexer.removeErrorListeners()
        lexer.addErrorListener(CustomErrorListener())
        parser = MySqlParser(CommonTokenStream(lexer))
        parser.removeErrorListeners()
        parser.addErrorListener(CustomErrorListener())
        tree = parser.root()
        return tree, None, None, None
    except SelfParseError as e:
        return None, e.line, e.column, e.msg
    except Exception as e:
        # logging.error() accepts no ``file`` keyword; passing one raised a
        # TypeError that hid the real error. logging.exception also records
        # the traceback, and a bare ``raise`` re-raises the original error.
        logging.exception("An error occurred: %s", e)
        raise
def parse_oracle_tree(src_sql: str):
    """Parse Oracle PL/SQL text, starting from the grammar's ``sql_script`` rule.

    Args:
        src_sql: SQL text to parse.

    Returns:
        ``(tree, None, None, None)`` on success, where ``tree`` is the value
        returned by ``PlSqlParser.sql_script()``; ``(None, line, column, msg)``
        when the lexer or parser reports a syntax error; or
        ``(None, -1, -1, '')`` for any other exception, which is logged with
        its traceback rather than propagated.
    """
    try:
        lexer = PlSqlLexer(InputStream(src_sql))
        # Replace ANTLR's default listeners (which print to stderr) with one
        # that raises SelfParseError, so each error is reported exactly once.
        lexer.removeErrorListeners()
        lexer.addErrorListener(CustomErrorListener())
        parser = PlSqlParser(CommonTokenStream(lexer))
        parser.removeErrorListeners()
        parser.addErrorListener(CustomErrorListener())
        tree = parser.sql_script()
        return tree, None, None, None
    except SelfParseError as e:
        return None, e.line, e.column, e.msg
    except Exception as e:
        # logging.error() accepts no ``file`` keyword; the old call raised a
        # TypeError, so this sentinel return was never reached.
        # NOTE(review): unlike the pg/mysql helpers, this swallows unexpected
        # errors; kept because callers may rely on the -1 sentinel.
        logging.exception("An error occurred: %s", e)
        return None, -1, -1, ''
def get_parser(dialect: str):
    """Build an ANTLR parser for *dialect* over an empty input stream.

    No custom error listener is attached. Presumably callers use the returned
    parser for grammar metadata (rule/token names) rather than for parsing.
    TODO: confirm against callers.

    Args:
        dialect: one of ``'pg'``, ``'mysql'`` or ``'oracle'``.

    Returns:
        A ``PostgreSQLParser``, ``MySqlParser`` or ``PlSqlParser`` instance.

    Raises:
        ValueError: if *dialect* is not one of the supported names.
    """
    grammar_classes = (
        ('pg', PostgreSQLLexer, PostgreSQLParser),
        ('mysql', MySqlLexer, MySqlParser),
        ('oracle', PlSqlLexer, PlSqlParser),
    )
    empty_input = InputStream('')
    for name, lexer_cls, parser_cls in grammar_classes:
        if dialect == name:
            token_stream = CommonTokenStream(lexer_cls(empty_input))
            return parser_cls(token_stream)
    raise ValueError(f"Only support {DIALECT_LIST}")
class SelfParseError(Exception):
    """Syntax error raised while lexing or parsing SQL text.

    Constructed from the ``line``/``column``/``msg`` values that
    ``CustomErrorListener.syntaxError`` receives. The ``parse_*_tree``
    helpers unpack these attributes into their return tuple.

    Attributes:
        line: line number of the reported error.
        column: column of the reported error.
        msg: the error message text.
    """

    def __init__(self, line, column, msg):
        super().__init__(f"Syntax error at line {line} , column {column} : {msg}")
        self.line = line
        self.column = column
        self.msg = msg

    def __reduce__(self):
        # By default, Exception pickling replays ``self.args`` (the single
        # formatted message) into __init__, which requires three arguments.
        # Unpickling would then fail with TypeError. Rebuild from the original
        # fields instead, keeping any extra instance state.
        return (type(self), (self.line, self.column, self.msg), self.__dict__)