Skip to content

Commit fa02939

Browse files
kssilveira authored and txtpbfmt-copybara-robot committed
No public description
FUTURE_COPYBARA_INTEGRATE_REVIEW=#179 from mhsong21:sort-by-descriptor-field 41497d5 PiperOrigin-RevId: 813239666
1 parent cf07efc commit fa02939

12 files changed

Lines changed: 512 additions & 6 deletions

File tree

ast/ast.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ type Node struct {
7777
// Used when we want to break between the field name and values when a
7878
// single-line node exceeds the requested wrap column.
7979
PutSingleValueOnNextLine bool
80+
// Field number from proto definition (0 if unknown/not applicable).
81+
FieldNumber int32
8082
}
8183

8284
// NodeLess is a sorting function that compares two *Nodes, possibly using the parent Node
@@ -267,6 +269,32 @@ func ByFieldSubfieldPath(field string, subfieldPath []string, projection func(st
267269
}
268270
}
269271

272+
// ByFieldNumber is a NodeLess function that orders fields by their field numbers.
273+
// Field numbers are populated during parsing from descriptor information.
274+
func ByFieldNumber(_, ni, nj *Node, isWholeSlice bool) bool {
275+
if !isWholeSlice {
276+
return false
277+
}
278+
279+
numI, numJ := ni.FieldNumber, nj.FieldNumber
280+
281+
// If both have field numbers, sort by field number
282+
if numI > 0 && numJ > 0 {
283+
return numI < numJ
284+
}
285+
286+
// If only one has field number, prioritize it
287+
if numI > 0 && numJ == 0 {
288+
return true // ni has priority
289+
}
290+
if numI == 0 && numJ > 0 {
291+
return false // nj has priority
292+
}
293+
294+
// If neither has field number, fall back to alphabetical order
295+
return ni.Name < nj.Name
296+
}
297+
270298
// getChildValue returns the Value of the child with the given field name,
271299
// or nil if no single such child exists.
272300
func (n *Node) getChildValue(field string) *Value {

cmd/txtpbfmt/fmt.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ var (
2424
expandAllChildren = flag.Bool("expand_all_children", false, "Expand all children irrespective of initial state.")
2525
skipAllColons = flag.Bool("skip_all_colons", false, "Skip colons whenever possible.")
2626
sortFieldsByFieldName = flag.Bool("sort_fields_by_field_name", false, "Sort fields by field name.")
27+
sortFieldsByFieldNumber = flag.Bool("sort_fields_by_field_number", false, "Sort fields by field number from proto definition.")
28+
protoDescriptor = flag.String("proto_descriptor", "", "Path to protobuf descriptor file (.desc)")
29+
messageFullName = flag.String("message_full_name", "", "Full message type name for field number lookup (required, e.g. google.protobuf.Any)")
2730
sortRepeatedFieldsByContent = flag.Bool("sort_repeated_fields_by_content", false, "Sort adjacent scalar fields of the same field name by their contents.")
2831
sortRepeatedFieldsBySubfield = flag.String("sort_repeated_fields_by_subfield", "", "Sort adjacent message fields of the given field name by the contents of the given subfield.")
2932
removeDuplicateValuesForRepeatedFields = flag.Bool("remove_duplicate_values_for_repeated_fields", false, "Remove lines that have the same field name and scalar value as another.")
@@ -88,6 +91,9 @@ func processPath(path string) error {
8891
ExpandAllChildren: *expandAllChildren,
8992
SkipAllColons: *skipAllColons,
9093
SortFieldsByFieldName: *sortFieldsByFieldName,
94+
SortFieldsByFieldNumber: *sortFieldsByFieldNumber,
95+
ProtoDescriptor: *protoDescriptor,
96+
MessageFullName: *messageFullName,
9197
SortRepeatedFieldsByContent: *sortRepeatedFieldsByContent,
9298
SortRepeatedFieldsBySubfield: strings.Split(*sortRepeatedFieldsBySubfield, ","),
9399
RemoveDuplicateValuesForRepeatedFields: *removeDuplicateValuesForRepeatedFields,

config/config.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@ type Config struct {
2424
// Sort fields by field name.
2525
SortFieldsByFieldName bool
2626

27+
// Sort fields by field number from proto definition.
28+
SortFieldsByFieldNumber bool
29+
30+
// Path to protobuf descriptor file (.desc).
31+
ProtoDescriptor string
32+
33+
// Full message type name for field number lookup (required, e.g. google.protobuf.Any).
34+
MessageFullName string
35+
2736
// Sort adjacent scalar fields of the same field name by their contents.
2837
SortRepeatedFieldsByContent bool
2938

descriptor/descriptor.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Package descriptor provides functionality to load and parse Protocol Buffer descriptor files.
2+
package descriptor
3+
4+
import (
5+
"fmt"
6+
"os"
7+
8+
"google.golang.org/protobuf/proto"
9+
"google.golang.org/protobuf/reflect/protodesc"
10+
"google.golang.org/protobuf/reflect/protoreflect"
11+
"google.golang.org/protobuf/reflect/protoregistry"
12+
13+
"google.golang.org/protobuf/types/descriptorpb"
14+
)
15+
16+
// Loader provides functionality to load field numbers from descriptor files.
17+
type Loader struct {
18+
descriptorFile string
19+
files *protoregistry.Files
20+
}
21+
22+
// NewLoader creates a new descriptor loader for the given descriptor file.
23+
func NewLoader(descriptorFile string) (*Loader, error) {
24+
if descriptorFile == "" {
25+
return nil, fmt.Errorf("descriptor file is required")
26+
}
27+
28+
data, err := os.ReadFile(descriptorFile)
29+
if err != nil {
30+
return nil, fmt.Errorf("failed to read descriptor file %s: %v", descriptorFile, err)
31+
}
32+
33+
fileDescSet := &descriptorpb.FileDescriptorSet{}
34+
if err := proto.Unmarshal(data, fileDescSet); err != nil {
35+
return nil, fmt.Errorf("failed to unmarshal descriptor file %s: %v", descriptorFile, err)
36+
}
37+
38+
files, err := protodesc.NewFiles(fileDescSet)
39+
if err != nil {
40+
return nil, fmt.Errorf("failed to create files from descriptor file %s: %v", descriptorFile, err)
41+
}
42+
43+
return &Loader{
44+
descriptorFile: descriptorFile,
45+
files: files,
46+
}, nil
47+
}
48+
49+
// GetRootMessageDescriptor returns the root message descriptor for the specified messageFullName.
50+
// messageFullName is required and must be a valid full name (e.g., "google.protobuf.Any").
51+
func (l *Loader) GetRootMessageDescriptor(messageFullName string) (protoreflect.MessageDescriptor, error) {
52+
if l.files == nil {
53+
return nil, fmt.Errorf("descriptor not loaded, call NewLoader() first")
54+
}
55+
56+
if messageFullName == "" {
57+
// Collect available messages to help user
58+
var availableMessages []string
59+
l.files.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
60+
messages := fd.Messages()
61+
for i := 0; i < messages.Len(); i++ {
62+
msg := messages.Get(i)
63+
availableMessages = append(availableMessages, string(msg.FullName()))
64+
}
65+
return true
66+
})
67+
68+
if len(availableMessages) == 0 {
69+
return nil, fmt.Errorf("No messages found in descriptor")
70+
}
71+
return nil, fmt.Errorf("message_full_name is required. Available messages: %v", availableMessages)
72+
}
73+
74+
// Find specific message type
75+
desc, err := l.files.FindDescriptorByName(protoreflect.FullName(messageFullName))
76+
if err != nil {
77+
return nil, fmt.Errorf("message type %s not found: %v", messageFullName, err)
78+
}
79+
if msgDesc, ok := desc.(protoreflect.MessageDescriptor); ok {
80+
return msgDesc, nil
81+
}
82+
return nil, fmt.Errorf("%s is not a message type", messageFullName)
83+
}

descriptor/descriptor_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package descriptor
2+
3+
import (
4+
"testing"
5+
6+
// Google internal testing/gobase/runfilestest package, commented out by copybara
7+
)
8+
9+
func TestNewLoader(t *testing.T) {
10+
t.Run("valid descriptor file", func(t *testing.T) {
11+
descriptorFile := "../testdata/test.desc"
12+
loader, err := NewLoader(descriptorFile)
13+
if err != nil {
14+
t.Fatalf("Failed to create loader: %v", err)
15+
}
16+
if loader == nil {
17+
t.Fatal("Expected non-nil loader")
18+
}
19+
})
20+
21+
t.Run("empty descriptor file path", func(t *testing.T) {
22+
_, err := NewLoader("")
23+
if err == nil {
24+
t.Error("Expected error for empty path")
25+
}
26+
})
27+
28+
t.Run("non-existent file", func(t *testing.T) {
29+
_, err := NewLoader("nonexistent.desc")
30+
if err == nil {
31+
t.Error("Expected error for non-existent file")
32+
}
33+
})
34+
}
35+
36+
func TestGetRootMessageDescriptor(t *testing.T) {
37+
descriptorFile := "../testdata/test.desc"
38+
loader, err := NewLoader(descriptorFile)
39+
if err != nil {
40+
t.Fatalf("Failed to create loader: %v", err)
41+
}
42+
43+
tests := []struct {
44+
name string
45+
messageFullName string
46+
wantError bool
47+
}{
48+
{"UserProfile", "testproto.UserProfile", false},
49+
{"ProductCatalog", "testproto.ProductCatalog", false},
50+
{"Level1Config", "testproto.Level1Config", false},
51+
{"nested message", "testproto.Level1Config.Level2Config", false},
52+
{"empty name", "", true},
53+
{"non-existent", "testproto.NonExistent", true},
54+
}
55+
56+
for _, tt := range tests {
57+
t.Run(tt.name, func(t *testing.T) {
58+
desc, err := loader.GetRootMessageDescriptor(tt.messageFullName)
59+
60+
if tt.wantError {
61+
if err == nil {
62+
t.Error("Expected error but got none")
63+
}
64+
} else {
65+
if err != nil {
66+
t.Errorf("Unexpected error: %v", err)
67+
}
68+
if desc == nil {
69+
t.Error("Expected descriptor but got nil")
70+
}
71+
}
72+
})
73+
}
74+
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ require (
77
github.com/google/go-cmp v0.6.0
88
github.com/kylelemons/godebug v1.1.0
99
github.com/mitchellh/go-wordwrap v1.0.1
10+
google.golang.org/protobuf v1.33.0
1011
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
66
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
77
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
88
github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0=
9+
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
10+
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=

impl/impl.go

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ import (
99
"strconv"
1010
"strings"
1111

12+
"google.golang.org/protobuf/reflect/protoreflect"
1213
"github.com/protocolbuffers/txtpbfmt/ast"
1314
"github.com/protocolbuffers/txtpbfmt/config"
15+
"github.com/protocolbuffers/txtpbfmt/descriptor"
1416
"github.com/protocolbuffers/txtpbfmt/quote"
1517
"github.com/protocolbuffers/txtpbfmt/sort"
1618
"github.com/protocolbuffers/txtpbfmt/wrap"
@@ -148,13 +150,33 @@ func ParseWithMetaCommentConfig(in []byte, c config.Config) ([]*ast.Node, error)
148150
if err != nil {
149151
return nil, err
150152
}
153+
154+
// Load descriptor if field number sorting is enabled
155+
var rootDesc protoreflect.MessageDescriptor
156+
if c.SortFieldsByFieldNumber {
157+
if c.ProtoDescriptor == "" {
158+
return nil, fmt.Errorf("proto_descriptor is required when using sort_fields_by_field_number")
159+
}
160+
161+
loader, err := descriptor.NewLoader(c.ProtoDescriptor)
162+
if err != nil {
163+
return nil, fmt.Errorf("failed to create descriptor loader: %v", err)
164+
}
165+
166+
// Get root message descriptor
167+
rootDesc, err = loader.GetRootMessageDescriptor(c.MessageFullName)
168+
if err != nil {
169+
return nil, fmt.Errorf("failed to get root message descriptor: %v", err)
170+
}
171+
}
172+
151173
if p.config.InfoLevel() {
152174
p.config.Infof("p.in: %q", string(p.in))
153175
p.config.Infof("p.length: %v", p.length)
154176
}
155177
// Although unnamed nodes aren't strictly allowed, some formats represent a
156178
// list of protos as a list of unnamed top-level nodes.
157-
nodes, _, err := p.parse( /*isRoot=*/ true)
179+
nodes, _, err := p.parse( /*isRoot=*/ true, rootDesc)
158180
if err != nil {
159181
return nil, err
160182
}
@@ -288,6 +310,35 @@ func newParser(in []byte, c config.Config) (*parser, error) {
288310
return parser, nil
289311
}
290312

313+
// getFieldNumber returns the field number for a given field name in the descriptor.
314+
func getFieldNumber(desc protoreflect.MessageDescriptor, fieldName string) int32 {
315+
if desc == nil {
316+
return 0
317+
}
318+
319+
field := desc.Fields().ByTextName(fieldName)
320+
if field == nil {
321+
return 0
322+
}
323+
return int32(field.Number())
324+
}
325+
326+
// findChildDescriptor finds the descriptor for a nested message field.
327+
func (p *parser) findChildDescriptor(desc protoreflect.MessageDescriptor, fieldName string) protoreflect.MessageDescriptor {
328+
if desc == nil {
329+
return nil
330+
}
331+
332+
field := desc.Fields().ByTextName(fieldName)
333+
if field == nil {
334+
return nil
335+
}
336+
if field.Kind() == protoreflect.MessageKind {
337+
return field.Message()
338+
}
339+
return nil
340+
}
341+
291342
func (p *parser) nextInputIs(b byte) bool {
292343
return p.index < p.length && p.in[p.index] == b
293344
}
@@ -398,7 +449,7 @@ func (p *parser) consumeOptionalSeparator() error {
398449
// format (sequence of messages, each of which passes proto.UnmarshalText()).
399450
// endPos is the position of the first character on the first line
400451
// after parsed nodes: that's the position to append more children.
401-
func (p *parser) parse(isRoot bool) (result []*ast.Node, endPos ast.Position, err error) {
452+
func (p *parser) parse(isRoot bool, desc protoreflect.MessageDescriptor) (result []*ast.Node, endPos ast.Position, err error) {
402453
var res []*ast.Node
403454
res = []*ast.Node{} // empty children is different from nil children
404455
for ld := p.getLoopDetector(); p.index < p.length; {
@@ -505,14 +556,17 @@ func (p *parser) parse(isRoot bool) (result []*ast.Node, endPos ast.Position, er
505556
return nil, ast.Position{}, err
506557
}
507558

559+
// Set field number from descriptor if available
560+
nd.FieldNumber = getFieldNumber(desc, nd.Name)
561+
508562
// Skip separator.
509563
preCommentsBeforeColon, _ := p.skipWhiteSpaceAndReadComments(true /* multiLine */)
510564
nd.SkipColon = !p.consume(':')
511565
previousPos := p.position()
512566
preCommentsAfterColon, _ := p.skipWhiteSpaceAndReadComments(true /* multiLine */)
513567

514568
if p.consume('{') || p.consume('<') {
515-
if err := p.parseMessage(nd); err != nil {
569+
if err := p.parseMessage(nd, desc); err != nil {
516570
return nil, ast.Position{}, err
517571
}
518572
} else if p.consume('[') {
@@ -562,14 +616,15 @@ func (p *parser) parseFieldName(nd *ast.Node, isRoot bool) error {
562616
return nil
563617
}
564618

565-
func (p *parser) parseMessage(nd *ast.Node) error {
619+
func (p *parser) parseMessage(nd *ast.Node, desc protoreflect.MessageDescriptor) error {
566620
if p.config.SkipAllColons {
567621
nd.SkipColon = true
568622
}
569623
nd.ChildrenSameLine = p.bracketSameLine[p.index-1]
570624
nd.IsAngleBracket = p.config.PreserveAngleBrackets && p.in[p.index-1] == '<'
571625
// Recursive call to parse child nodes.
572-
nodes, lastPos, err := p.parse( /*isRoot=*/ false)
626+
childDesc := p.findChildDescriptor(desc, nd.Name)
627+
nodes, lastPos, err := p.parse( /*isRoot=*/ false, childDesc)
573628
if err != nil {
574629
return err
575630
}
@@ -595,7 +650,7 @@ func (p *parser) parseList(nd *ast.Node, preCommentsBeforeColon, preCommentsAfte
595650
// Handle list of nodes.
596651
nd.ChildrenAsList = true
597652

598-
nodes, lastPos, err := p.parse( /*isRoot=*/ true)
653+
nodes, lastPos, err := p.parse( /*isRoot=*/ true, nil)
599654
if err != nil {
600655
return err
601656
}

0 commit comments

Comments
 (0)