Skip to content

Commit 57f2d6d

Browse files
committed
chore: add a formatting tool
1 parent 63396a5 commit 57f2d6d

1 file changed

Lines changed: 204 additions & 0 deletions

File tree

scripts/format.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import argparse
2+
import subprocess
3+
import os
4+
from pathlib import Path
5+
from colorama import Fore, Style
6+
7+
# Supported file types and their corresponding formatter categories
8+
SUPPORTED_FILES = {
9+
".h": "c",
10+
".hh": "c",
11+
".hpp": "c",
12+
".c": "c",
13+
".cc": "c",
14+
".cpp": "c",
15+
".cxx": "c",
16+
".cu": "c",
17+
".cuh": "c",
18+
".mlu": "c",
19+
".cl": "c",
20+
".py": "py",
21+
}
22+
23+
24+
def format_file(file: Path, check: bool, formatter) -> bool:
25+
formatter = formatter.get(SUPPORTED_FILES.get(file.suffix, None), None)
26+
if not formatter:
27+
return True # Unsupported file type, skip
28+
29+
try:
30+
cmd = []
31+
if formatter.startswith("clang-format"):
32+
cmd = [formatter, "-style=file", "-i", file]
33+
if check:
34+
cmd.insert(2, "-dry-run")
35+
process = subprocess.run(
36+
cmd,
37+
capture_output=True,
38+
text=True,
39+
check=True,
40+
)
41+
if process.stderr:
42+
print(f"{Fore.YELLOW}{file} is not formatted.{Style.RESET_ALL}")
43+
print(
44+
f"Use {Fore.CYAN}{formatter} -style=file -i {file}{Style.RESET_ALL} to format it."
45+
)
46+
return False
47+
else:
48+
subprocess.run(
49+
cmd,
50+
capture_output=True,
51+
text=True,
52+
check=True,
53+
)
54+
print(f"{Fore.CYAN}Formatted: {file}{Style.RESET_ALL}")
55+
elif formatter == "black":
56+
cmd = [formatter, file]
57+
if check:
58+
cmd.insert(1, "--check")
59+
process = subprocess.run(
60+
cmd,
61+
capture_output=True,
62+
text=True,
63+
check=True,
64+
)
65+
if process.returncode != 0:
66+
print(f"{Fore.YELLOW}{file} is not formatted.{Style.RESET_ALL}")
67+
print(
68+
f"Use {Fore.CYAN}{formatter} {file}{Style.RESET_ALL} to format it."
69+
)
70+
return False
71+
else:
72+
subprocess.run(
73+
cmd,
74+
capture_output=True,
75+
text=True,
76+
check=True,
77+
)
78+
print(f"{Fore.CYAN}Formatted: {file}{Style.RESET_ALL}")
79+
except FileNotFoundError:
80+
print(
81+
f"{Fore.RED}Formatter {formatter} not found, {file} skipped.{Style.RESET_ALL}"
82+
)
83+
except subprocess.CalledProcessError as e:
84+
print(f"{Fore.RED}Formatter {formatter} failed: {e}{Style.RESET_ALL}")
85+
86+
return True
87+
88+
89+
def git_added_files():
90+
"""Get all staged files"""
91+
try:
92+
# Use git diff --cached --name-only to get all files added to staging area
93+
result = subprocess.run(
94+
["git", "diff", "--cached", "--diff-filter=AMR", "--name-only"],
95+
capture_output=True,
96+
text=True,
97+
check=True,
98+
)
99+
for file in result.stdout.splitlines():
100+
yield Path(file.strip())
101+
except subprocess.CalledProcessError as e:
102+
print(f"{Fore.RED}Git diff failed: {e}{Style.RESET_ALL}")
103+
104+
105+
def git_modified_since_ref(ref):
106+
"""Get list of files modified from the specified Git reference to the current state"""
107+
try:
108+
result = subprocess.run(
109+
["git", "diff", f"{ref}..", "--diff-filter=AMR", "--name-only"],
110+
capture_output=True,
111+
text=True,
112+
check=True,
113+
)
114+
for file in result.stdout.splitlines():
115+
yield Path(file.strip())
116+
except subprocess.CalledProcessError as e:
117+
print(f"{Fore.RED}Git diff failed: {e}{Style.RESET_ALL}")
118+
119+
120+
def list_files(paths):
121+
"""Recursively get all files under the specified paths"""
122+
files = []
123+
for path in paths:
124+
if path.is_file():
125+
yield path
126+
elif path.is_dir():
127+
for dirpath, _, filenames in os.walk(path):
128+
for name in filenames:
129+
yield Path(dirpath) / name
130+
else:
131+
print(
132+
f"{Fore.RED}Error: {path} is not a file or directory.{Style.RESET_ALL}"
133+
)
134+
135+
136+
def filter_in_path(file: Path, path) -> bool:
137+
"""Check if file is within the specified paths"""
138+
for p in path:
139+
if file.is_relative_to(p):
140+
return True
141+
return False
142+
143+
144+
def main():
145+
parser = argparse.ArgumentParser()
146+
parser.add_argument(
147+
"--ref", type=str, help="Git reference (commit hash) to compare against."
148+
)
149+
parser.add_argument(
150+
"--path", nargs="*", type=Path, help="Files to format or check."
151+
)
152+
parser.add_argument(
153+
"--check", action="store_true", help="Check files without modifying them."
154+
)
155+
parser.add_argument(
156+
"--c", default="clang-format-16", help="C formatter (default: clang-format-16)"
157+
)
158+
parser.add_argument(
159+
"--py", default="black", help="Python formatter (default: black)"
160+
)
161+
args = parser.parse_args()
162+
163+
if args.ref is None and args.path is None:
164+
# Last commit.
165+
print(f"{Fore.GREEN}Formatting git staged files.{Style.RESET_ALL}")
166+
files = git_added_files()
167+
168+
else:
169+
if args.ref is None:
170+
print(f"{Fore.GREEN}Formatting files in {args.path}.{Style.RESET_ALL}")
171+
files = list_files(args.path)
172+
elif args.path is None:
173+
print(
174+
f"{Fore.GREEN}Formatting git modified files from {args.ref}.{Style.RESET_ALL}"
175+
)
176+
files = git_modified_since_ref(args.ref)
177+
else:
178+
print(
179+
f"{Fore.GREEN}Formatting git modified files from {args.ref} in {args.path}.{Style.RESET_ALL}"
180+
)
181+
files = (
182+
file
183+
for file in git_modified_since_ref(args.ref)
184+
if filter_in_path(file, args.path)
185+
)
186+
187+
formatted = True
188+
for file in files:
189+
if not format_file(
190+
file,
191+
args.check,
192+
{
193+
"c": args.c,
194+
"py": args.py,
195+
},
196+
):
197+
formatted = False
198+
199+
if not formatted:
200+
exit(1)
201+
202+
203+
if __name__ == "__main__":
204+
main()

0 commit comments

Comments
 (0)