Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions nbs/00_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -450,14 +450,15 @@
"#export\n",
"def preformat_statements(s):\n",
" \"\"\"Write a newline in `s` for all `statements` and\n",
" uppercase them but not if they are inside a comment\"\"\"\n",
" uppercase them but not if they are inside a comment or quoted text\"\"\"\n",
" statements = MAIN_STATEMENTS\n",
" s = clean_query(s) # clean query and mark comments\n",
" split_s = split_query(s) # split by comment and non comment\n",
" split_s = compress_dicts(split_s, [\"comment\", \"select\"])\n",
" # compile regex before loop\n",
" create_re = re.compile(r\"\\bcreate\\b\", flags=re.I)\n",
" select_re = re.compile(r\"\\bselect\\b\", flags=re.I)\n",
" quote_split = False\n",
" for statement in statements:\n",
" if create_re.match(statement): # special case CREATE with AS capitalize as well\n",
" create_sub = re.compile(rf\"\\s*({statement} )(.*) as\\b\", flags=re.I)\n",
Expand All @@ -470,17 +471,22 @@
" \"select\": sdict[\"select\"]\n",
" } for sdict in split_s]\n",
" else: # normal main statements\n",
" if not quote_split:\n",
" split_s = split_query(\"\".join([sdict[\"string\"] for sdict in split_s]))\n",
" split_s = compress_dicts(split_s, [\"comment\", \"quote\", \"select\"])\n",
" quote_split = True\n",
" non_select_region_re = re.compile(rf\"\\s*\\b({statement})\\b\", flags=re.I)\n",
" select_region_statement_re = re.compile(rf\"\\b({statement})\\b\", flags=re.I)\n",
" split_s = [{\n",
" \"string\": non_select_region_re.sub(\"\\n\" + statement.upper(), sdict[\"string\"]) \n",
" if not sdict[\"comment\"] and not sdict[\"select\"] # no comment, no select region\n",
" if not sdict[\"comment\"] and not sdict[\"quote\"] and not sdict[\"select\"] # no comment, no quote, no select region\n",
" else non_select_region_re.sub(\"\\n\" + statement.upper(), sdict[\"string\"]) \n",
" if not sdict[\"comment\"] and sdict[\"select\"] and select_re.match(statement) # no comment, select region and select statement\n",
" if not sdict[\"comment\"] and not sdict[\"quote\"] and sdict[\"select\"] and select_re.match(statement) # no comment, no quote, select region and select statement\n",
" else select_region_statement_re.sub(statement.upper(), sdict[\"string\"]) \n",
" if not sdict[\"comment\"] and sdict[\"select\"] and not select_re.match(statement) # no comment, select region and no select statement\n",
" if not sdict[\"comment\"] and not sdict[\"quote\"] and sdict[\"select\"] and not select_re.match(statement) # no comment, no quote, select region and no select statement\n",
" else sdict[\"string\"],\n",
" \"comment\": sdict[\"comment\"],\n",
" \"quote\": sdict[\"quote\"],\n",
" \"select\": sdict[\"select\"]\n",
" } for sdict in split_s]\n",
" s = \"\".join([sdict[\"string\"] for sdict in split_s])\n",
Expand Down
16 changes: 16 additions & 0 deletions nbs/99_additional_tests.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,22 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert_and_print(\n",
" format_sql(\"select * from rooms where command = 'on'\"),\n",
"\"\"\"\n",
"SELECT *\n",
"FROM rooms\n",
"WHERE command = 'on'\n",
"\"\"\".strip()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
16 changes: 11 additions & 5 deletions sql_formatter/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,15 @@ def clean_query(s):
# Cell
def preformat_statements(s):
"""Write a newline in `s` for all `statements` and
uppercase them but not if they are inside a comment"""
uppercase them but not if they are inside a comment or quoted text"""
statements = MAIN_STATEMENTS
s = clean_query(s) # clean query and mark comments
split_s = split_query(s) # split by comment and non comment
split_s = compress_dicts(split_s, ["comment", "select"])
# compile regex before loop
create_re = re.compile(r"\bcreate\b", flags=re.I)
select_re = re.compile(r"\bselect\b", flags=re.I)
quote_split = False
for statement in statements:
if create_re.match(statement): # special case CREATE with AS capitalize as well
create_sub = re.compile(rf"\s*({statement} )(.*) as\b", flags=re.I)
Expand All @@ -65,17 +66,22 @@ def preformat_statements(s):
"select": sdict["select"]
} for sdict in split_s]
else: # normal main statements
if not quote_split:
split_s = split_query("".join([sdict["string"] for sdict in split_s]))
split_s = compress_dicts(split_s, ["comment", "quote", "select"])
quote_split = True
non_select_region_re = re.compile(rf"\s*\b({statement})\b", flags=re.I)
select_region_statement_re = re.compile(rf"\b({statement})\b", flags=re.I)
split_s = [{
"string": non_select_region_re.sub("\n" + statement.upper(), sdict["string"])
if not sdict["comment"] and not sdict["select"] # no comment, no select region
if not sdict["comment"] and not sdict["quote"] and not sdict["select"] # no comment, no quote, no select region
else non_select_region_re.sub("\n" + statement.upper(), sdict["string"])
if not sdict["comment"] and sdict["select"] and select_re.match(statement) # no comment, select region and select statement
if not sdict["comment"] and not sdict["quote"] and sdict["select"] and select_re.match(statement) # no comment, no quote, select region and select statement
else select_region_statement_re.sub(statement.upper(), sdict["string"])
if not sdict["comment"] and sdict["select"] and not select_re.match(statement) # no comment, select region and no select statement
if not sdict["comment"] and not sdict["quote"] and sdict["select"] and not select_re.match(statement) # no comment, no quote, select region and no select statement
else sdict["string"],
"comment": sdict["comment"],
"quote": sdict["quote"],
"select": sdict["select"]
} for sdict in split_s]
s = "".join([sdict["string"] for sdict in split_s])
Expand Down Expand Up @@ -482,4 +488,4 @@ def format_sql(s, semicolon=False, max_len=82):
subquery_pos = extract_outer_subquery(s)
# remove whitespace between word and parenthesis
s = re.sub(r"\s*\)", ")", s)
return s
return s