Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This module now requires [Go 1.21](https://go.dev/doc/go1.21) or higher.

### Added
- Added support for [buf.build](https://buf.build/alta/protopatch).
- Added support for running tools with Go1.24's `go tool` by specifying tool=true

### Notes
- Changelog from here forward will only include major dependency updates.
Expand Down
15 changes: 11 additions & 4 deletions cmd/protoc-gen-go-patch/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package main

import (
"fmt"
"io/ioutil"
"io"
"log"
"os"
"path/filepath"
"strconv"
"strings"

"github.com/alta/protopatch/patch"
Expand Down Expand Up @@ -36,12 +37,18 @@ func run() error {
}

var plugin string
var useGoTool bool

opts := protogen.Options{
ParamFunc: func(name, value string) error {
switch name {
case "plugin":
plugin = value
case "tool":
useGoTool, err = strconv.ParseBool(value)
if err != nil {
return err
}
}
return nil // Ignore unknown params.
},
Expand All @@ -58,14 +65,14 @@ func run() error {
}

if os.Getenv("PROTO_PATCH_DEBUG_LOGGING") == "" {
log.SetOutput(ioutil.Discard)
log.SetOutput(io.Discard)
}

// Strip our custom param(s).
patch.StripParam(gen.Request, "plugin")
patch.StripParams(gen.Request, []string{"plugin", "tool"})

// Run the specified plugin and unmarshal the CodeGeneratorResponse.
res, err := patch.RunPlugin(plugin, gen.Request, nil)
res, err := patch.RunPlugin(plugin, gen.Request, nil, useGoTool)
if err != nil {
return err
}
Expand Down
26 changes: 17 additions & 9 deletions patch/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,29 @@ package patch
import (
"bytes"
"io"
"io/ioutil"
"os"
"os/exec"
"slices"
"strings"

"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/pluginpb"
)

// StripParam strips a named param from req.
func StripParam(req *pluginpb.CodeGeneratorRequest, p string) {
// StripParams strips a named param from req.
func StripParams(req *pluginpb.CodeGeneratorRequest, p []string) {
if req.Parameter == nil {
return
}
v := stripParam(*req.Parameter, p)
v := stripParams(*req.Parameter, p)
req.Parameter = &v

}

func stripParam(s, p string) string {
func stripParams(s string, p []string) string {
var b strings.Builder
for _, param := range strings.Split(s, ",") {
if strings.SplitN(param, "=", 2)[0] != p {
if !slices.Contains(p, strings.SplitN(param, "=", 2)[0]) {
if b.Len() > 0 {
b.WriteString(",")
}
Expand All @@ -38,7 +38,7 @@ func stripParam(s, p string) string {
// RunPlugin runs a protoc plugin named "protoc-gen-$plugin"
// and returns the generated CodeGeneratorResponse or an error.
// Supply a non-nil stderr to override stderr on the called plugin.
func RunPlugin(plugin string, req *pluginpb.CodeGeneratorRequest, stderr io.Writer) (*pluginpb.CodeGeneratorResponse, error) {
func RunPlugin(plugin string, req *pluginpb.CodeGeneratorRequest, stderr io.Writer, useGoTool bool) (*pluginpb.CodeGeneratorResponse, error) {
if stderr == nil {
stderr = os.Stderr
}
Expand All @@ -50,8 +50,16 @@ func RunPlugin(plugin string, req *pluginpb.CodeGeneratorRequest, stderr io.Writ
}

// Call the plugin with the modified CodeGeneratorRequest.
cmdName := "protoc-gen-" + plugin
var args []string

if useGoTool {
args = []string{"tool", cmdName}
cmdName = "go"
}

var buf bytes.Buffer
cmd := exec.Command("protoc-gen-" + plugin)
cmd := exec.Command(cmdName, args...)
cmd.Stdin = bytes.NewReader(b)
cmd.Stdout = &buf
cmd.Stderr = stderr
Expand All @@ -71,7 +79,7 @@ func RunPlugin(plugin string, req *pluginpb.CodeGeneratorRequest, stderr io.Writ

// ReadRequest reads and unmarshals a CodeGeneratorRequest.
func ReadRequest(r io.Reader) (*pluginpb.CodeGeneratorRequest, error) {
in, err := ioutil.ReadAll(os.Stdin)
in, err := io.ReadAll(r)
if err != nil {
return nil, err
}
Expand Down