Skip to content

Commit 2c9f1d9

Browse files
committed
feat: 优化SQL查询的行数限制实现方法与新增SQL详情展示功能
- 添加sqlglot工具类，按SQL方言处理行数限制和数据导出的统计语句
- 新增SQL详情展示模态框，支持查看原始SQL和实际执行的SQL
- 优化查询日志记录，增加original_sql字段（原语句）
- 改进SQL处理逻辑，支持CTE语法（WITH）查询
- 修复测试用例以匹配新功能
1 parent 10f8ce5 commit 2c9f1d9

14 files changed

Lines changed: 544 additions & 52 deletions

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,4 @@ boto3
5151
azure_storage_blob==12.26.0
5252
openpyxl==3.1.5
5353
parameterized
54-
54+
sqlglot

sql/engines/goinception.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def query_data_masking(self, instance, db_name=None, sql=""):
203203
sql = f"""/*--user={user};--password={password};--host={host};--port={port};--masking=1;*/
204204
inception_magic_start;
205205
use `{db_name}`;
206-
{sql}
206+
{sql};
207207
inception_magic_commit;"""
208208
query_result = self.query(db_name=db_name, sql=sql)
209209
# 有异常时主动抛出
@@ -218,7 +218,7 @@ def query_data_masking(self, instance, db_name=None, sql=""):
218218
if print_info.get("errlevel") == 0 and print_info.get("errmsg") is None:
219219
return json.loads(print_info["query_tree"])
220220
else:
221-
raise RuntimeError(f'Inception Error: print_info.get("errmsg")')
221+
raise RuntimeError(f'Inception Error: {print_info.get("errmsg")}')
222222

223223
def get_rollback(self, workflow):
224224
"""

sql/engines/mssql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,12 @@ def query_check(self, db_name=None, sql=""):
282282
]
283283
keyword_warning = ""
284284
star_patter = r"(^|,|\s)\*(\s|\(|$)"
285-
sql_whitelist = ["select", "sp_helptext"]
285+
sql_whitelist = ["select", "sp_helptext", "with"]
286286
# 根据白名单list拼接pattern语句
287287
whitelist_pattern = "^" + "|^".join(sql_whitelist)
288288
# 删除注释语句,进行语法判断,执行第一条有效sql
289289
try:
290-
sql = sql.format(sql, strip_comments=True)
290+
sql = sqlparse.format(sql, strip_comments=True)
291291
sql = sqlparse.split(sql)[0]
292292
result["filtered_sql"] = sql.strip()
293293
sql_lower = sql.lower()

sql/engines/mysql.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -559,14 +559,14 @@ def query_check(self, db_name=None, sql=""):
559559
except IndexError:
560560
result["bad_query"] = True
561561
result["msg"] = "没有有效的SQL语句"
562-
if re.match(r"^select|^show|^explain", sql, re.I) is None:
562+
if re.match(r"^select|^show|^explain|^with", sql, re.I) is None:
563563
result["bad_query"] = True
564564
result["msg"] = "不支持的查询语法类型!"
565565
if "*" in sql:
566566
result["has_star"] = True
567567
result["msg"] = "SQL语句中含有 * "
568-
# select语句先使用Explain判断语法是否正确
569-
if re.match(r"^select", sql, re.I):
568+
# select和with语句先使用Explain判断语法是否正确
569+
if re.match(r"^select|^with", sql, re.I):
570570
explain_result = self.query(db_name=db_name, sql=f"explain {sql}")
571571
if explain_result.error:
572572
result["bad_query"] = True

sql/engines/odps.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,17 +105,9 @@ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
105105
"""返回 ResultSet"""
106106
result_set = ResultSet(full_sql=sql)
107107

108-
if not re.match(r"^select", sql, re.I):
108+
if not re.match(r"^select|^with", sql, re.I):
109109
result_set.error = str("仅支持ODPS查询语句")
110110

111-
# 存在limit,替换limit; 不存在,添加limit
112-
if re.search("limit", sql):
113-
sql = re.sub("limit.+(\d+)", "limit " + str(limit_num), sql)
114-
else:
115-
if sql.strip()[-1] == ";":
116-
sql = sql[:-1]
117-
sql = sql + " limit " + str(limit_num) + ";"
118-
119111
try:
120112
conn = self.get_connection(db_name)
121113
effect_row = conn.execute_sql(sql)
@@ -136,7 +128,7 @@ def query_check(self, db_name=None, sql=""):
136128
# 查询语句的检查、注释去除、切分
137129
result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
138130
keyword_warning = ""
139-
sql_whitelist = ["select"]
131+
sql_whitelist = ["select", "with"]
140132
# 根据白名单list拼接pattern语句
141133
whitelist_pattern = re.compile("^" + "|^".join(sql_whitelist), re.IGNORECASE)
142134
# 删除注释语句,进行语法判断,执行第一条有效sql

sql/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,7 @@ class QueryLog(models.Model):
618618
instance_name = models.CharField("实例名称", max_length=50)
619619
db_name = models.CharField("数据库名称", max_length=64)
620620
sqllog = models.TextField("执行的查询语句")
621+
original_sql = models.TextField("原始查询语句")
621622
effect_row = models.BigIntegerField("返回行数")
622623
cost_time = models.CharField("执行耗时", max_length=10, default="")
623624
# TODO 改为user 外键

sql/offlinedownload.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import datetime
99
import xml.etree.ElementTree as ET
1010
import zipfile
11+
from numpy import diag
1112
import sqlparse
1213
import time
1314

@@ -22,6 +23,7 @@
2223
from sql.storage import DynamicStorage
2324
from sql.engines import get_engine
2425
from common.config import SysConfig
26+
from sql.utils.sql_utils import SqlglotUtils
2527

2628
logger = logging.getLogger("default")
2729

@@ -134,7 +136,8 @@ def pre_count_check(self, workflow):
134136
full_sql = sqlparse.format(full_sql, strip_comments=True)
135137
full_sql = sqlparse.split(full_sql)[0]
136138
sql = full_sql.strip()
137-
count_sql = f"SELECT COUNT(*) FROM ({sql.rstrip(';')}) t"
139+
dialect = SqlglotUtils.get_dialect(workflow.db_type)
140+
count_sql = SqlglotUtils.wrap_query_with_count(sql, dialect)
138141
clean_sql = sql.strip().lower()
139142
instance = workflow
140143
check_result = ReviewSet(full_sql=sql)

sql/query.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sql.utils.tasks import add_kill_conn_schedule, del_schedule
2020
from .models import QueryLog, Instance
2121
from sql.engines import get_engine
22+
from sql.utils.sql_utils import SqlglotUtils
2223

2324
logger = logging.getLogger("default")
2425

@@ -36,6 +37,7 @@ def query(request):
3637
tb_name = request.POST.get("tb_name")
3738
limit_num = int(request.POST.get("limit_num", 0))
3839
schema_name = request.POST.get("schema_name", None)
40+
is_offline_export = int(request.POST.get("is_offline_export", 0))
3941
user = request.user
4042

4143
result = {"status": 0, "msg": "ok", "data": {}}
@@ -68,6 +70,7 @@ def query(request):
6870
result["msg"] = query_check_info.get("msg")
6971
return HttpResponse(json.dumps(result), content_type="application/json")
7072
sql_content = query_check_info["filtered_sql"]
73+
original_sql = sql_content.strip()
7174

7275
# 查询权限校验,并且获取limit_num
7376
priv_check_info = query_priv_check(
@@ -82,9 +85,22 @@ def query(request):
8285
return HttpResponse(json.dumps(result), content_type="application/json")
8386
# explain的limit_num设置为0
8487
limit_num = 0 if re.match(r"^explain", sql_content.lower()) else limit_num
85-
86-
# 对查询sql增加limit限制或者改写语句
87-
sql_content = query_engine.filter_sql(sql=sql_content, limit_num=limit_num)
88+
dialect = SqlglotUtils.get_dialect(instance.db_type)
89+
if is_offline_export:
90+
# 离线导出,统计总数,需要包装count查询语句
91+
sql_content = SqlglotUtils.wrap_query_with_count(sql_content, dialect)
92+
else:
93+
# 页面查询,增加行数限制
94+
# 支持sqlglot方言转换的,使用方言添加行数限制
95+
if dialect:
96+
sql_content = SqlglotUtils.add_limit_to_query(
97+
sql_content, limit_num, dialect
98+
)
99+
else:
100+
# 不支持sqlglot方言的,使用引擎filter_sql函数处理
101+
sql_content = query_engine.filter_sql(
102+
sql=sql_content, limit_num=limit_num
103+
)
88104

89105
# 先获取查询连接,用于后面查询复用连接以及终止会话
90106
query_engine.get_connection(db_name=db_name)
@@ -176,6 +192,7 @@ def query(request):
176192
db_name=db_name,
177193
instance_name=instance.instance_name,
178194
sqllog=sql_content,
195+
original_sql=original_sql,
179196
effect_row=limit_num,
180197
cost_time=query_result.query_time,
181198
priv_check=priv_check,
@@ -273,6 +290,7 @@ def _querylog(request):
273290
"instance_name",
274291
"db_name",
275292
"sqllog",
293+
"original_sql",
276294
"effect_row",
277295
"cost_time",
278296
"user_display",

sql/templates/sqlexportsubmit.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,6 @@
527527
}
528528

529529
var sqlContent = sqlContent.trim().replace(/;$/, '');
530-
var sqlContent = 'select count(1) from (' + sqlContent + '\n) t'
531530

532531
//提交请求
533532
$.ajax({
@@ -540,6 +539,7 @@
540539
schema_name: $("#schema_name").val(),
541540
tb_name: $("#table_name").val(),
542541
sql_content: sqlContent,
542+
is_offline_export: 1,
543543
limit_num: $("#limit_num").val()
544544
},
545545
complete: function () {

0 commit comments

Comments
 (0)