Skip to content

Commit 291807a

Browse files
kalbasitclaude
andauthored
feat: fix Go acronym casing in generated code (#13)
This adds a post-processing step to fix common Go naming convention issues in generated code. The fixAcronyms function corrects acronyms like Api->API, Id->ID, Sql->SQL, Url->URL, etc. This ensures generated Go code follows standard Go naming conventions for acronyms. --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7fd9eca commit 291807a

5 files changed

Lines changed: 211 additions & 5 deletions

File tree

example/pkg/database/generated_wrapper_mysql.go

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

example/pkg/database/mysqldb/query.mysql.sql.go

Lines changed: 9 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

example/pkg/database/sqlitedb/query.sqlite.sql.go

Lines changed: 9 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

generator/helpers.go

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"log"
66
"os"
77
"path/filepath"
8+
"regexp"
89
"strings"
910
"unicode"
1011

@@ -52,6 +53,66 @@ func extractBulkFor(comment string) string {
5253

5354
func toSingular(s string) string { return inflection.Singular(s) }
5455

56+
// FixAcronyms corrects common Go acronym casing issues using word-boundary-aware
57+
// regex replacements to avoid corrupting words that contain acronyms as substrings.
58+
// For example: Id -> ID, Api -> API, Sql -> SQL, Url -> URL, Xml -> XML.
59+
func FixAcronyms(content []byte) []byte {
60+
// Common Go acronyms that should be all caps, with their correct form.
61+
acronyms := []struct {
62+
pattern string
63+
replacement string
64+
}{
65+
{"Acl", "ACL"},
66+
{"Api", "API"},
67+
{"Cpu", "CPU"},
68+
{"Ec2", "EC2"},
69+
{"Ebs", "EBS"},
70+
{"Html", "HTML"},
71+
{"Id", "ID"},
72+
{"Io", "IO"},
73+
{"Ip", "IP"},
74+
{"Json", "JSON"},
75+
{"Jwt", "JWT"},
76+
{"Sql", "SQL"},
77+
{"Ssh", "SSH"},
78+
{"Tcp", "TCP"},
79+
{"Tls", "TLS"},
80+
{"Udp", "UDP"},
81+
{"Url", "URL"},
82+
{"Xml", "XML"},
83+
}
84+
85+
result := string(content)
86+
87+
for _, a := range acronyms {
88+
// Pre-compile regexes once per acronym (not inside inner loop).
89+
// Use patterns to handle acronyms in different positions:
90+
// 1. Start: `^(Acronym)([A-Z])` - acronym at start followed by uppercase, e.g., Xml in XMLParser.
91+
// 2. AfterUpper: `([A-Z])(Acronym)([A-Z])` - acronym between uppercase, e.g., Html in UserHTMLDoc.
92+
// 3. Mid: `([a-z])(Acronym)([A-Z])` - acronym in middle of camelCase, e.g., Id in userIdMore.
93+
// 4. End: `([a-z])(Acronym)$` - acronym at end of identifier, e.g., Id in userId.
94+
// 5. NonLetter: `([a-z])(Acronym)([^A-Za-z])` - acronym followed by non-letter.
95+
regexStart := regexp.MustCompile(`^(` + a.pattern + `)([A-Z])`)
96+
regexAfterUpper := regexp.MustCompile(`([A-Z])(` + a.pattern + `)([A-Z])`)
97+
regexMid := regexp.MustCompile(`([a-z])(` + a.pattern + `)([A-Z])`)
98+
regexEnd := regexp.MustCompile(`([a-z])(` + a.pattern + `)$`)
99+
regexNonLetter := regexp.MustCompile(`([a-z])(` + a.pattern + `)([^A-Za-z])`)
100+
101+
// Start case: replace with replacement followed by ${2} (the uppercase after).
102+
result = regexStart.ReplaceAllString(result, a.replacement+"${2}")
103+
// After uppercase case: preserve surrounding uppercase via ${1} and ${3}.
104+
result = regexAfterUpper.ReplaceAllString(result, "${1}"+a.replacement+"${3}")
105+
// Middle case: preserve the following uppercase letter via ${3}.
106+
result = regexMid.ReplaceAllString(result, "${1}"+a.replacement+"${3}")
107+
// Non-letter case: preserve the following character via ${3}.
108+
result = regexNonLetter.ReplaceAllString(result, "${1}"+a.replacement+"${3}")
109+
// End case: no ${3} since there's no following letter.
110+
result = regexEnd.ReplaceAllString(result, "${1}"+a.replacement)
111+
}
112+
113+
return []byte(result)
114+
}
115+
55116
func writeFile(dir, filename string, content []byte) {
56117
// 1. Manage imports with goimports
57118
withImports, err := imports.Process(filename, content, nil)
@@ -70,7 +131,10 @@ func writeFile(dir, filename string, content []byte) {
70131
log.Fatalf("formatting %s: %v", filename, err)
71132
}
72133

73-
if err := os.WriteFile(filepath.Join(dir, filename), formatted, 0o644); err != nil { //nolint:gosec
134+
// 3. Fix acronym casing (Api -> API, Id -> ID, etc.)
135+
fixed := FixAcronyms(formatted)
136+
137+
if err := os.WriteFile(filepath.Join(dir, filename), fixed, 0o644); err != nil { //nolint:gosec
74138
log.Fatal(err)
75139
}
76140

generator/helpers_test.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package generator_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/kalbasit/sqlc-multi-db/generator"
7+
)
8+
9+
func TestFixAcronyms(t *testing.T) {
10+
t.Parallel()
11+
12+
tests := []struct {
13+
name string
14+
input string
15+
expected string
16+
}{
17+
{
18+
name: "Simple Id to ID",
19+
input: "userId",
20+
expected: "userID",
21+
},
22+
{
23+
name: "Url to URL",
24+
input: "profileUrl",
25+
expected: "profileURL",
26+
},
27+
{
28+
name: "Api to API",
29+
input: "fetchApiData",
30+
expected: "fetchAPIData",
31+
},
32+
{
33+
name: "Sql to SQL",
34+
input: "execSqlQuery",
35+
expected: "execSQLQuery",
36+
},
37+
{
38+
name: "Json to JSON",
39+
input: "parseJsonBody",
40+
expected: "parseJSONBody",
41+
},
42+
{
43+
name: "Identifier should not be corrupted",
44+
input: "Identifier",
45+
expected: "Identifier",
46+
},
47+
{
48+
name: "Curling should not be corrupted to CURLing",
49+
input: "Curling",
50+
expected: "Curling",
51+
},
52+
{
53+
name: "XmlParser should become XMLParser",
54+
input: "XmlParser",
55+
expected: "XMLParser",
56+
},
57+
{
58+
name: "HtmlDocument should become HTMLDocument",
59+
input: "HtmlDocument",
60+
expected: "HTMLDocument",
61+
},
62+
{
63+
name: "Multiple acronyms in one string",
64+
input: "userId and profileUrl",
65+
expected: "userID and profileURL",
66+
},
67+
{
68+
name: "Id at end of string",
69+
input: "GetUserId",
70+
expected: "GetUserID",
71+
},
72+
{
73+
name: "Already correct ID should stay",
74+
input: "GetUserID",
75+
expected: "GetUserID",
76+
},
77+
{
78+
name: "Already correct URL should stay",
79+
input: "profileURL",
80+
expected: "profileURL",
81+
},
82+
{
83+
name: "Tcp connection",
84+
input: "openTcpConnection",
85+
expected: "openTCPConnection",
86+
},
87+
{
88+
name: "Jwt token",
89+
input: "validateJwtToken",
90+
expected: "validateJWTToken",
91+
},
92+
{
93+
name: "Ec2 instance",
94+
input: "launchEc2Instance",
95+
expected: "launchEC2Instance",
96+
},
97+
{
98+
name: "Json web token Jwt",
99+
input: "parseJsonWebToken",
100+
expected: "parseJSONWebToken",
101+
},
102+
{
103+
name: "Uuid should not change (not in our list)",
104+
input: "userUuid",
105+
expected: "userUuid",
106+
},
107+
{
108+
name: "Empty string",
109+
input: "",
110+
expected: "",
111+
},
112+
}
113+
114+
for _, tt := range tests {
115+
t.Run(tt.name, func(t *testing.T) {
116+
t.Parallel()
117+
118+
result := generator.FixAcronyms([]byte(tt.input))
119+
if string(result) != tt.expected {
120+
t.Errorf("FixAcronyms(%q) = %q, want %q", tt.input, string(result), tt.expected)
121+
}
122+
})
123+
}
124+
}

0 commit comments

Comments
 (0)