Skip to content

Commit 83cfd8a

Browse files
committed
Add gRPC support to mains plugin
This change enhances the mains plugin to support gRPC servers in addition to HTTP servers, addressing issue #209. Key changes: - Only generate HTTP server code when the service design includes HTTP endpoints - Only generate gRPC server code when the service design includes gRPC endpoints - Always generate the metrics HTTP server (health/metrics/debug endpoints) - Fix duplicate package/import declarations in generated main.go - Fix WebSocket detection bug (Stream value of 0 means no streaming, not NoStreamKind) - Conditionally import transport-specific packages based on what's actually used - Add proper OTel instrumentation for gRPC servers - Implement graceful shutdown for both HTTP and gRPC servers The plugin now: 1. Scans the DSL for HTTP and gRPC endpoint definitions 2. Conditionally generates server initialization code based on transports used 3. Manages imports efficiently (only includes websocket, grpc packages when needed) 4. Provides a unified main.go that can run HTTP-only, gRPC-only, or both transports 5. Maintains the existing single-service (services/<svc>/cmd/<svc>) and multi-service (cmd/<server>) layouts Fixes #209
1 parent 11a756c commit 83cfd8a

2 files changed

Lines changed: 158 additions & 40 deletions

File tree

mains/generate.go

Lines changed: 101 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ type srvInfo struct {
3131
APIPkg string
3232
Services []*service.Data
3333
HasWS bool
34+
HasHTTP bool
35+
HasGRPC bool
3436
ServerName string
3537
}
3638

@@ -43,7 +45,11 @@ type svcT struct {
4345
SrvVar string
4446
GenPkg string
4547
GenHTTPPkg string
48+
GenGRPCPkg string
49+
GenGRPCPbPkg string
4650
HasWebSocket bool
51+
HasHTTP bool
52+
HasGRPC bool
4753
}
4854

4955
// Register the plugin for the example phase.
@@ -109,21 +115,32 @@ func generateExample(genpkg string, roots []eval.Root, files []*codegen.File) ([
109115
apipkg := apiPkgAlias(genpkg, roots)
110116
if info, ok := srvMap[dir]; ok {
111117
info.HasWS = hasWS
118+
info.HasHTTP = true
112119
if info.APIPkg == "" { info.APIPkg = apipkg }
113120
} else {
114-
srvMap[dir] = &srvInfo{Dir: dir, APIPkg: apipkg, Services: svcs, HasWS: hasWS}
121+
srvMap[dir] = &srvInfo{Dir: dir, APIPkg: apipkg, Services: svcs, HasWS: hasWS, HasHTTP: true}
122+
}
123+
}
124+
// Detect gRPC servers from grpc.go files
125+
for _, f := range files {
126+
if filepath.Base(f.Path) != "grpc.go" { continue }
127+
segs := strings.Split(filepath.ToSlash(f.Path), "/")
128+
if len(segs) < 3 || segs[0] != "cmd" { continue }
129+
dir := segs[1]
130+
if info, ok := srvMap[dir]; ok {
131+
info.HasGRPC = true
115132
}
116133
}
117134

118135
if len(srvMap) == 0 {
119136
return files, nil
120137
}
121138

122-
// Filter out default example mains and http.go; we'll add our own mains.
139+
// Filter out default example mains, http.go, and grpc.go; we'll add our own mains.
123140
var out []*codegen.File
124141
for _, f := range files {
125142
base := filepath.Base(f.Path)
126-
if strings.HasPrefix(f.Path, "cmd/") && (base == "main.go" || base == "http.go") {
143+
if strings.HasPrefix(f.Path, "cmd/") && (base == "main.go" || base == "http.go" || base == "grpc.go") {
127144
continue
128145
}
129146
out = append(out, f)
@@ -156,6 +173,14 @@ func generateExample(genpkg string, roots []eval.Root, files []*codegen.File) ([
156173
codegen.GoaNamedImport("http", "goahttp"),
157174
{Path: "google.golang.org/grpc/credentials/insecure"},
158175
}
176+
if info.HasGRPC {
177+
specs = append(specs,
178+
&codegen.ImportSpec{Path: "net"},
179+
&codegen.ImportSpec{Path: "google.golang.org/grpc"},
180+
&codegen.ImportSpec{Path: "google.golang.org/grpc/reflection"},
181+
&codegen.ImportSpec{Path: "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"},
182+
)
183+
}
159184
if info.HasWS {
160185
specs = append(specs, &codegen.ImportSpec{Path: "github.com/gorilla/websocket"})
161186
}
@@ -164,19 +189,45 @@ func generateExample(genpkg string, roots []eval.Root, files []*codegen.File) ([
164189

165190
scope := codegen.NewNameScope()
166191
var svcsData []svcT
192+
httpBySvc := httpServicesByName(roots)
193+
grpcBySvc := grpcServicesByName(roots)
167194
wsBySvc := httpWebSocketByService(roots)
168195
hasAnyWS := false
196+
hasAnyHTTP := false
197+
hasAnyGRPC := false
169198
for _, sd := range info.Services {
170199
genAlias := scope.Unique(sd.PkgName, "svc")
171-
httpAlias := scope.Unique(sd.PkgName+"svr", "svr")
172-
specs = append(specs,
173-
&codegen.ImportSpec{Path: path.Join(genpkg, sd.PathName), Name: genAlias},
174-
&codegen.ImportSpec{Path: path.Join(genpkg, "http", sd.PathName, "server"), Name: httpAlias},
175-
)
200+
hasHTTP := httpBySvc[sd.Name]
201+
hasGRPC := grpcBySvc[sd.Name]
176202
hws := wsBySvc[sd.Name]
203+
204+
var httpAlias, grpcAlias, grpcPbAlias string
205+
206+
// Always add the base service package
207+
specs = append(specs, &codegen.ImportSpec{Path: path.Join(genpkg, sd.PathName), Name: genAlias})
208+
209+
// Conditionally add HTTP server imports
210+
if hasHTTP {
211+
httpAlias = scope.Unique(sd.PkgName+"svr", "svr")
212+
specs = append(specs, &codegen.ImportSpec{Path: path.Join(genpkg, "http", sd.PathName, "server"), Name: httpAlias})
213+
hasAnyHTTP = true
214+
}
215+
216+
// Conditionally add gRPC server imports
217+
if hasGRPC {
218+
grpcAlias = scope.Unique(sd.PkgName+"grpc", "grpcsvc")
219+
grpcPbAlias = scope.Unique(sd.PkgName+"pb", "pb")
220+
specs = append(specs,
221+
&codegen.ImportSpec{Path: path.Join(genpkg, "grpc", sd.PathName, "server"), Name: grpcAlias},
222+
&codegen.ImportSpec{Path: path.Join(genpkg, "grpc", sd.PathName, "pb"), Name: grpcPbAlias},
223+
)
224+
hasAnyGRPC = true
225+
}
226+
177227
if hws {
178228
hasAnyWS = true
179229
}
230+
180231
svcsData = append(svcsData, svcT{
181232
Name: sd.Name,
182233
StructName: sd.StructName,
@@ -185,7 +236,11 @@ func generateExample(genpkg string, roots []eval.Root, files []*codegen.File) ([
185236
SrvVar: sd.VarName + "Server",
186237
GenPkg: genAlias,
187238
GenHTTPPkg: httpAlias,
239+
GenGRPCPkg: grpcAlias,
240+
GenGRPCPbPkg: grpcPbAlias,
188241
HasWebSocket: hws,
242+
HasHTTP: hasHTTP,
243+
HasGRPC: hasGRPC,
189244
})
190245
}
191246

@@ -195,6 +250,8 @@ func generateExample(genpkg string, roots []eval.Root, files []*codegen.File) ([
195250
"APIPkg": info.APIPkg,
196251
"Services": svcsData,
197252
"HasAnyWebSocket": hasAnyWS,
253+
"HasHTTP": hasAnyHTTP,
254+
"HasGRPC": hasAnyGRPC,
198255
"ServiceCount": len(svcsData),
199256
"ServerLabel": serverLabel(roots),
200257
}},
@@ -265,7 +322,8 @@ func httpWebSocketByService(roots []eval.Root) map[string]bool {
265322
if e.SSE != nil {
266323
continue
267324
}
268-
if e.MethodExpr != nil && e.MethodExpr.Stream != expr.NoStreamKind {
325+
// Stream is 0 when no streaming is defined, and >= NoStreamKind (1) when streaming is used
326+
if e.MethodExpr != nil && e.MethodExpr.Stream != 0 {
269327
hasWS[svc.Name()] = true
270328
break
271329
}
@@ -286,3 +344,37 @@ func rootServer(roots []eval.Root) *expr.ServerExpr {
286344
}
287345
return nil
288346
}
347+
348+
// httpServicesByName returns map of service names that have HTTP endpoints.
349+
func httpServicesByName(roots []eval.Root) map[string]bool {
350+
hasHTTP := map[string]bool{}
351+
for _, r := range roots {
352+
root, ok := r.(*expr.RootExpr)
353+
if !ok || root.API == nil || root.API.HTTP == nil {
354+
continue
355+
}
356+
for _, svc := range root.API.HTTP.Services {
357+
if len(svc.HTTPEndpoints) > 0 {
358+
hasHTTP[svc.Name()] = true
359+
}
360+
}
361+
}
362+
return hasHTTP
363+
}
364+
365+
// grpcServicesByName returns map of service names that have gRPC endpoints.
366+
func grpcServicesByName(roots []eval.Root) map[string]bool {
367+
hasGRPC := map[string]bool{}
368+
for _, r := range roots {
369+
root, ok := r.(*expr.RootExpr)
370+
if !ok || root.API == nil || root.API.GRPC == nil {
371+
continue
372+
}
373+
for _, svc := range root.API.GRPC.Services {
374+
if len(svc.GRPCEndpoints) > 0 {
375+
hasGRPC[svc.Name()] = true
376+
}
377+
}
378+
}
379+
return hasGRPC
380+
}

mains/templates/main.go.tpl

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,11 @@
1-
package main
2-
3-
import (
4-
"context"
5-
"flag"
6-
"fmt"
7-
"net/http"
8-
"net/http/httptrace"
9-
"os"
10-
"os/signal"
11-
"sync"
12-
"syscall"
13-
"time"
14-
15-
"go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace"
16-
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
17-
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
18-
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
19-
"goa.design/clue/clue"
20-
"goa.design/clue/debug"
21-
"goa.design/clue/health"
22-
"goa.design/clue/log"
23-
goahttp "goa.design/goa/v3/http"
24-
{{- if .HasAnyWebSocket }}
25-
"github.com/gorilla/websocket"
26-
{{- end }}
27-
"google.golang.org/grpc/credentials/insecure"
28-
)
29-
301
func main() {
312
var (
3+
{{- if .HasHTTP }}
324
httpaddr = flag.String("http-addr", ":8080", "HTTP listen address")
5+
{{- end }}
6+
{{- if .HasGRPC }}
7+
grpcaddr = flag.String("grpc-addr", ":9090", "gRPC listen address")
8+
{{- end }}
339
metricsAddr = flag.String("metrics-addr", ":8081", "metrics listen address")
3410
coladdr = flag.String("otel-addr", ":4317", "OpenTelemetry collector listen address")
3511
debugf = flag.Bool("debug", false, "Enable debug logs")
@@ -126,6 +102,7 @@ func main() {
126102
{{ .EpVar }}.Use(log.Endpoint)
127103
{{- end }}
128104

105+
{{- if .HasHTTP }}
129106
// 6. Create HTTP transport
130107
mux := goahttp.NewMuxer()
131108
debug.MountDebugLogEnabler(debug.Adapt(mux))
@@ -139,6 +116,7 @@ func main() {
139116
{{- end }}
140117

141118
{{- range .Services }}
119+
{{- if .HasHTTP }}
142120
// {{ .Name }} HTTP server
143121
{{- if .HasWebSocket }}
144122
{{ .SrvVar }} := {{ .GenHTTPPkg }}.New({{ .EpVar }}, mux, goahttp.RequestDecoder, goahttp.ResponseEncoder, nil, nil, upgrader, nil)
@@ -150,10 +128,37 @@ func main() {
150128
log.Print(ctx, log.KV{K: "method", V: m.Method}, log.KV{K: "endpoint", V: m.Verb + " " + m.Pattern})
151129
}
152130
{{- end }}
131+
{{- end }}
153132

154133
httpServer := &http.Server{Addr: *httpaddr, Handler: handler}
134+
{{- end }}
135+
136+
{{- if .HasGRPC }}
137+
// 6b. Create gRPC server with interceptors
138+
var grpcServerOpts []grpc.ServerOption
139+
grpcServerOpts = append(grpcServerOpts, grpc.StatsHandler(otelgrpc.NewServerHandler()))
140+
grpcServerOpts = append(grpcServerOpts, grpc.ChainUnaryInterceptor(
141+
log.UnaryServerInterceptor(ctx),
142+
debug.UnaryServerInterceptor(),
143+
))
144+
grpcServerOpts = append(grpcServerOpts, grpc.ChainStreamInterceptor(
145+
log.StreamServerInterceptor(ctx),
146+
debug.StreamServerInterceptor(),
147+
))
148+
grpcServer := grpc.NewServer(grpcServerOpts...)
149+
150+
{{- range .Services }}
151+
{{- if .HasGRPC }}
152+
// {{ .Name }} gRPC server
153+
{{ .SvcVar }}GRPCServer := {{ .GenGRPCPkg }}.New({{ .EpVar }}, nil)
154+
{{ .GenGRPCPbPkg }}.Register{{ .StructName }}Server(grpcServer, {{ .SvcVar }}GRPCServer)
155+
{{- end }}
156+
{{- end }}
155157

156-
// 7. Start HTTP servers (graceful shutdown)
158+
reflection.Register(grpcServer)
159+
{{- end }}
160+
161+
// 7. Start servers (graceful shutdown)
157162
errc := make(chan error)
158163
go func() {
159164
c := make(chan os.Signal, 1)
@@ -167,18 +172,32 @@ func main() {
167172
go func() {
168173
defer wg.Done()
169174

175+
{{- if .HasHTTP }}
170176
go func() {
171177
log.Printf(ctx, "HTTP server listening on %s", *httpaddr)
172178
errc <- httpServer.ListenAndServe()
173179
}()
180+
{{- end }}
181+
182+
{{- if .HasGRPC }}
183+
go func() {
184+
lis, err := net.Listen("tcp", *grpcaddr)
185+
if err != nil {
186+
errc <- err
187+
return
188+
}
189+
log.Printf(ctx, "gRPC server listening on %s", *grpcaddr)
190+
errc <- grpcServer.Serve(lis)
191+
}()
192+
{{- end }}
174193

175194
go func() {
176195
log.Printf(ctx, "Metrics server listening on %s", *metricsAddr)
177196
errc <- metricsServer.ListenAndServe()
178197
}()
179198

180199
<-ctx.Done()
181-
log.Printf(ctx, "shutting down HTTP servers")
200+
log.Printf(ctx, "shutting down servers")
182201

183202
// Shutdown gracefully with a 30s timeout.
184203
sctx, scancel := context.WithTimeout(context.Background(), 30*time.Second)
@@ -192,9 +211,16 @@ func main() {
192211
}
193212
{{- end }}
194213

214+
{{- if .HasHTTP }}
195215
if err := httpServer.Shutdown(sctx); err != nil {
196216
log.Errorf(sctx, err, "failed to shutdown HTTP server")
197217
}
218+
{{- end }}
219+
220+
{{- if .HasGRPC }}
221+
grpcServer.GracefulStop()
222+
{{- end }}
223+
198224
if err := metricsServer.Shutdown(sctx); err != nil {
199225
log.Errorf(sctx, err, "failed to shutdown metrics server")
200226
}

0 commit comments

Comments
 (0)