Skip to content

Commit de48132

Browse files
Support SQL formatting inside textproto string fields marked with '# txtpbfmt sql-format'.
PiperOrigin-RevId: 869641916
1 parent fcb97cc commit de48132

4 files changed

Lines changed: 193 additions & 2 deletions

File tree

impl/impl.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/protocolbuffers/txtpbfmt/descriptor"
1616
"github.com/protocolbuffers/txtpbfmt/quote"
1717
"github.com/protocolbuffers/txtpbfmt/sort"
18+
"github.com/protocolbuffers/txtpbfmt/sql"
1819
"github.com/protocolbuffers/txtpbfmt/wrap"
1920
)
2021

@@ -188,6 +189,9 @@ func ParseWithMetaCommentConfig(in []byte, c config.Config) ([]*ast.Node, error)
188189
if p.index < p.length {
189190
return nil, fmt.Errorf("parser didn't consume all input. Stopped at %s", p.errorContext())
190191
}
192+
if err := sql.Format(nodes); err != nil {
193+
return nil, err
194+
}
191195
if err := wrap.Strings(nodes, 0, c); err != nil {
192196
return nil, err
193197
}
@@ -201,7 +205,7 @@ func ParseWithMetaCommentConfig(in []byte, c config.Config) ([]*ast.Node, error)
201205
// have the equal sign. Currently there are only two MetaComments that are in the former format:
202206
//
203207
// "sort_repeated_fields_by_subfield": If this appears multiple times, then they will all be added
204-
// to the config and the order is perserved.
208+
// to the config and the order is preserved.
205209
// "wrap_strings_at_column": The <val> is expected to be an integer. If it is not, then it will be
206210
// ignored. If this appears multiple times, only the last one saved.
207211
func addToConfig(metaComment string, c *config.Config) error {

sql.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Package sql provides functions for formatting SQL inside textproto string fields.
2+
package sql
3+
4+
import (
5+
"strconv"
6+
"strings"
7+
8+
"google3/storage/googlesql/public/go/lenientformatter"
9+
"github.com/protocolbuffers/txtpbfmt/ast"
10+
"github.com/protocolbuffers/txtpbfmt/unquote"
11+
)
12+
13+
const sqlFormatComment = "# txtpbfmt sql-format"
14+
15+
// Format formats SQL strings in the given nodes.
16+
func Format(nodes []*ast.Node) error {
17+
for _, nd := range nodes {
18+
if nd.ChildrenSameLine {
19+
continue
20+
}
21+
if err := formatNode(nd); err != nil {
22+
return err
23+
}
24+
if err := Format(nd.Children); err != nil {
25+
return err
26+
}
27+
}
28+
return nil
29+
}
30+
31+
func formatNode(nd *ast.Node) error {
32+
hasComment := false
33+
for _, c := range nd.PreComments {
34+
if strings.TrimSpace(c) == sqlFormatComment {
35+
hasComment = true
36+
break
37+
}
38+
}
39+
if !hasComment {
40+
return nil
41+
}
42+
if len(nd.Values) == 0 {
43+
return nil
44+
}
45+
46+
// Unquote the string(s).
47+
sql, _, err := unquote.Unquote(nd)
48+
if err != nil {
49+
// If it's not a valid string, we can't format it as SQL.
50+
return nil
51+
}
52+
53+
formatted, err := lenientformatter.FormatSQL(sql)
54+
if err != nil {
55+
// If SQL formatting fails, we keep the original.
56+
return nil
57+
}
58+
59+
lines := strings.Split(formatted, "\n")
60+
// If there's a trailing newline, Split might produce an extra empty line.
61+
if len(lines) > 1 && lines[len(lines)-1] == "" {
62+
lines = lines[:len(lines)-1]
63+
}
64+
65+
newValues := make([]*ast.Value, 0, len(lines))
66+
for i, line := range lines {
67+
val := line
68+
if i < len(lines)-1 {
69+
// Add a space instead of a newline to separate SQL lines.
70+
// This results in a cleaner look in the textproto.
71+
if !strings.HasSuffix(val, " ") {
72+
val += " "
73+
}
74+
}
75+
newValues = append(newValues, &ast.Value{
76+
Value: strconv.Quote(val),
77+
})
78+
}
79+
80+
// Preserve comments if possible.
81+
if len(nd.Values) > 0 {
82+
if len(nd.Values[0].PreComments) > 0 {
83+
newValues[0].PreComments = nd.Values[0].PreComments
84+
}
85+
lastIdx := len(nd.Values) - 1
86+
if len(nd.Values[lastIdx].InlineComment) > 0 {
87+
newValues[len(newValues)-1].InlineComment = nd.Values[lastIdx].InlineComment
88+
}
89+
}
90+
91+
nd.Values = newValues
92+
if len(lines) > 1 {
93+
nd.PutSingleValueOnNextLine = true
94+
}
95+
96+
return nil
97+
}

sql_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package sql_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/kylelemons/godebug/diff"
7+
"github.com/protocolbuffers/txtpbfmt/parser"
8+
"github.com/protocolbuffers/txtpbfmt/sql"
9+
)
10+
11+
func TestFormat(t *testing.T) {
12+
tests := []struct {
13+
name string
14+
in string
15+
want string
16+
}{
17+
{
18+
name: "simple SQL",
19+
in: `# txtpbfmt sql-format
20+
query: "select * from table where id = 1"
21+
`,
22+
want: `# txtpbfmt sql-format
23+
query: "SELECT * FROM table WHERE id = 1"
24+
`,
25+
},
26+
{
27+
name: "multiline SQL",
28+
in: `# txtpbfmt sql-format
29+
query: "select a, b, c from table1 join table2 on table1.id = table2.id where a > 10 group by 1, 2, 3 order by 1"
30+
`,
31+
want: `# txtpbfmt sql-format
32+
query:
33+
"SELECT a, b, c "
34+
"FROM table1 "
35+
"JOIN table2 "
36+
" ON table1.id = table2.id "
37+
"WHERE a > 10 "
38+
"GROUP BY 1, 2, 3 "
39+
"ORDER BY 1"
40+
`,
41+
},
42+
{
43+
name: "no comment, no format",
44+
in: `query: "select * from table"
45+
`,
46+
want: `query: "select * from table"
47+
`,
48+
},
49+
{
50+
name: "wrong comment, no format",
51+
in: `# some other comment
52+
query: "select * from table"
53+
`,
54+
want: `# some other comment
55+
query: "select * from table"
56+
`,
57+
},
58+
}
59+
60+
for _, tc := range tests {
61+
t.Run(tc.name, func(t *testing.T) {
62+
nodes, err := parser.Parse([]byte(tc.in))
63+
if err != nil {
64+
t.Fatalf("Parse() err = %v", err)
65+
}
66+
if err := sql.Format(nodes); err != nil {
67+
t.Fatalf("Format() err = %v", err)
68+
}
69+
got := parser.Pretty(nodes, 0)
70+
if diff := diff.Diff(tc.want, got); diff != "" {
71+
t.Errorf("Format() diff (-want +got):\n%s", diff)
72+
}
73+
})
74+
}
75+
}

wrap/wrap.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ import (
1414

1515
var tagRegex = regexp.MustCompile(`<.*>`)
1616

17-
const indentSpaces = " "
17+
const (
18+
indentSpaces = " "
19+
sqlFormatComment = "# txtpbfmt sql-format"
20+
)
1821

1922
// Strings wraps the strings in the given nodes.
2023
func Strings(nodes []*ast.Node, depth int, c config.Config) error {
@@ -25,6 +28,9 @@ func Strings(nodes []*ast.Node, depth int, c config.Config) error {
2528
if nd.ChildrenSameLine {
2629
continue
2730
}
31+
if hasSQLFormatComment(nd) {
32+
continue
33+
}
2834
if err := wrapNodeStrings(nd, depth, c); err != nil {
2935
return err
3036
}
@@ -35,6 +41,15 @@ func Strings(nodes []*ast.Node, depth int, c config.Config) error {
3541
return nil
3642
}
3743

44+
func hasSQLFormatComment(nd *ast.Node) bool {
45+
for _, c := range nd.PreComments {
46+
if strings.TrimSpace(c) == sqlFormatComment {
47+
return true
48+
}
49+
}
50+
return false
51+
}
52+
3853
func wrapNodeStrings(nd *ast.Node, depth int, c config.Config) error {
3954
if c.WrapStringsAtColumn > 0 && needsWrappingAtColumn(nd, depth, c) {
4055
if err := wrapLinesAtColumn(nd, depth, c); err != nil {

0 commit comments

Comments
 (0)