Skip to content

Commit 3604ae1

Browse files
committed
Allow generating only async or sync code (connectrpc#214)
Signed-off-by: Anuraag Agrawal <anuraaga@gmail.com>
1 parent bcc604a commit 3604ae1

6 files changed

Lines changed: 259 additions & 10 deletions

File tree

protoc-gen-connect-python/generator/config.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ type Config struct {
3434

3535
// Imports is how to import dependencies in the generated code.
3636
Imports Imports
37+
38+
// Async indicates whether to only generate asynchronous code. If false,
39+
// only synchronous code will be generated. If nil, both synchronous and
40+
// asynchronous code will be generated.
41+
Async *bool
3742
}
3843

3944
func parseConfig(p string) Config {
@@ -64,6 +69,15 @@ func parseConfig(p string) Config {
6469
case "relative":
6570
cfg.Imports = ImportsRelative
6671
}
72+
case "async":
73+
switch value {
74+
case "true":
75+
trueVal := true
76+
cfg.Async = &trueVal
77+
case "false":
78+
falseVal := false
79+
cfg.Async = &falseVal
80+
}
6781
}
6882
}
6983
return cfg

protoc-gen-connect-python/generator/generator.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ func generateConnectFile(fd protoreflect.FileDescriptor, conf Config) (string, s
5555
ModuleName: moduleName,
5656
Imports: importStatements(fd, conf),
5757
}
58+
if conf.Async != nil {
59+
if *conf.Async {
60+
vars.SkipSync = true
61+
} else {
62+
vars.SkipAsync = true
63+
}
64+
}
5865

5966
svcs := fd.Services()
6067
packageName := string(fd.Package())
@@ -109,7 +116,7 @@ func generateConnectFile(fd protoreflect.FileDescriptor, conf Config) (string, s
109116
vars.Services = append(vars.Services, connectSvc)
110117
}
111118

112-
var buf = &bytes.Buffer{}
119+
buf := &bytes.Buffer{}
113120
err := ConnectTemplate.Execute(buf, vars)
114121
if err != nil {
115122
return "", "", fmt.Errorf("failed to execute template: %w", err)

protoc-gen-connect-python/generator/generator_test.go

Lines changed: 131 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,11 @@ func TestGenerate(t *testing.T) {
127127
t.Parallel()
128128

129129
tests := []struct {
130-
name string
131-
req *pluginpb.CodeGeneratorRequest
132-
wantStrings []string
133-
wantErr bool
130+
name string
131+
req *pluginpb.CodeGeneratorRequest
132+
wantStrings []string
133+
dontWantStrings []string
134+
wantErr bool
134135
}{
135136
{
136137
name: "empty request",
@@ -195,7 +196,127 @@ func TestGenerate(t *testing.T) {
195196
},
196197
},
197198
wantErr: false,
198-
wantStrings: []string{"def try_(self"},
199+
wantStrings: []string{"class TestServiceASGIApplication", "class TestServiceWSGIApplication"},
200+
},
201+
{
202+
name: "async only",
203+
req: &pluginpb.CodeGeneratorRequest{
204+
FileToGenerate: []string{"test.proto"},
205+
Parameter: proto.String("async=true"),
206+
ProtoFile: []*descriptorpb.FileDescriptorProto{
207+
{
208+
Name: proto.String("test.proto"),
209+
Package: proto.String("test"),
210+
Dependency: []string{"other.proto"},
211+
Service: []*descriptorpb.ServiceDescriptorProto{
212+
{
213+
Name: proto.String("TestService"),
214+
Method: []*descriptorpb.MethodDescriptorProto{
215+
{
216+
Name: proto.String("TestMethod"),
217+
InputType: proto.String(".test.TestRequest"),
218+
OutputType: proto.String(".test.TestResponse"),
219+
},
220+
{
221+
Name: proto.String("TestMethod2"),
222+
InputType: proto.String(".otherpackage.OtherRequest"),
223+
OutputType: proto.String(".otherpackage.OtherResponse"),
224+
},
225+
// Reserved keyword
226+
{
227+
Name: proto.String("Try"),
228+
InputType: proto.String(".otherpackage.OtherRequest"),
229+
OutputType: proto.String(".otherpackage.OtherResponse"),
230+
},
231+
},
232+
},
233+
},
234+
MessageType: []*descriptorpb.DescriptorProto{
235+
{
236+
Name: proto.String("TestRequest"),
237+
},
238+
{
239+
Name: proto.String("TestResponse"),
240+
},
241+
},
242+
},
243+
{
244+
Name: proto.String("other.proto"),
245+
Package: proto.String("otherpackage"),
246+
MessageType: []*descriptorpb.DescriptorProto{
247+
{
248+
Name: proto.String("OtherRequest"),
249+
},
250+
{
251+
Name: proto.String("OtherResponse"),
252+
},
253+
},
254+
},
255+
},
256+
},
257+
wantErr: false,
258+
wantStrings: []string{"class TestServiceASGIApplication"},
259+
dontWantStrings: []string{"class TestServiceWSGIApplication"},
260+
},
261+
{
262+
name: "sync only",
263+
req: &pluginpb.CodeGeneratorRequest{
264+
FileToGenerate: []string{"test.proto"},
265+
Parameter: proto.String("async=false"),
266+
ProtoFile: []*descriptorpb.FileDescriptorProto{
267+
{
268+
Name: proto.String("test.proto"),
269+
Package: proto.String("test"),
270+
Dependency: []string{"other.proto"},
271+
Service: []*descriptorpb.ServiceDescriptorProto{
272+
{
273+
Name: proto.String("TestService"),
274+
Method: []*descriptorpb.MethodDescriptorProto{
275+
{
276+
Name: proto.String("TestMethod"),
277+
InputType: proto.String(".test.TestRequest"),
278+
OutputType: proto.String(".test.TestResponse"),
279+
},
280+
{
281+
Name: proto.String("TestMethod2"),
282+
InputType: proto.String(".otherpackage.OtherRequest"),
283+
OutputType: proto.String(".otherpackage.OtherResponse"),
284+
},
285+
// Reserved keyword
286+
{
287+
Name: proto.String("Try"),
288+
InputType: proto.String(".otherpackage.OtherRequest"),
289+
OutputType: proto.String(".otherpackage.OtherResponse"),
290+
},
291+
},
292+
},
293+
},
294+
MessageType: []*descriptorpb.DescriptorProto{
295+
{
296+
Name: proto.String("TestRequest"),
297+
},
298+
{
299+
Name: proto.String("TestResponse"),
300+
},
301+
},
302+
},
303+
{
304+
Name: proto.String("other.proto"),
305+
Package: proto.String("otherpackage"),
306+
MessageType: []*descriptorpb.DescriptorProto{
307+
{
308+
Name: proto.String("OtherRequest"),
309+
},
310+
{
311+
Name: proto.String("OtherResponse"),
312+
},
313+
},
314+
},
315+
},
316+
},
317+
wantErr: false,
318+
wantStrings: []string{"class TestServiceWSGIApplication"},
319+
dontWantStrings: []string{"class TestServiceASGIApplication"},
199320
},
200321
}
201322

@@ -219,6 +340,11 @@ func TestGenerate(t *testing.T) {
219340
t.Errorf("generate() missing expected string: %v", s)
220341
}
221342
}
343+
for _, s := range tt.dontWantStrings {
344+
if strings.Contains(resp.GetFile()[0].GetContent(), s) {
345+
t.Errorf("generate() contains unexpected string: %v", s)
346+
}
347+
}
222348
}
223349
})
224350
}

protoc-gen-connect-python/generator/template.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ type ConnectTemplateVariables struct {
1313
ModuleName string
1414
Imports []ImportStatement
1515
Services []*ConnectService
16+
SkipAsync bool
17+
SkipSync bool
1618
}
1719

1820
type ConnectService struct {
@@ -61,9 +63,9 @@ from connectrpc.server import ConnectASGIApplication, ConnectWSGIApplication, En
6163
{{if .Relative}}from . import {{.Name}}{{else}}import {{.Name}}{{end}} as {{.Alias}}
6264
{{- end}}
6365
{{- end}}
64-
{{- range .Services}}
65-
6666
67+
{{if not .SkipAsync }}
68+
{{- range .Services}}
6769
class {{.Name}}(Protocol):{{- range .Methods }}
6870
{{if not .ResponseStream }}async {{end}}def {{.PythonName}}(self, request: {{if .RequestStream}}AsyncIterator[{{end}}{{.InputType}}{{if .RequestStream}}]{{end}}, ctx: RequestContext) -> {{if .ResponseStream}}AsyncIterator[{{end}}{{.OutputType}}{{if .ResponseStream}}]{{end}}:
6971
raise ConnectError(Code.UNIMPLEMENTED, "Not implemented")
@@ -124,6 +126,9 @@ class {{.Name}}Client(ConnectClient):{{range .Methods}}
124126
{{- end}}
125127
)
126128
{{end}}{{- end }}
129+
{{end}}
130+
131+
{{if not .SkipSync }}
127132
{{range .Services}}
128133
class {{.Name}}Sync(Protocol):{{- range .Methods }}
129134
def {{.PythonName}}(self, request: {{if .RequestStream}}Iterator[{{end}}{{.InputType}}{{if .RequestStream}}]{{end}}, ctx: RequestContext) -> {{if .ResponseStream}}Iterator[{{end}}{{.OutputType}}{{if .ResponseStream}}]{{end}}:
@@ -184,4 +189,6 @@ class {{.Name}}ClientSync(ConnectClientSync):{{range .Methods}}
184189
use_get=use_get,
185190
{{- end}}
186191
)
187-
{{end}}{{end}}`))
192+
{{end}}{{end}}
193+
{{end}}
194+
`))

src/connectrpc/_server_async.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,10 @@ async def __call__(
232232
ctx,
233233
)
234234
except Exception as e:
235-
return await self._handle_error(e, ctx, send)
235+
await self._handle_error(e, ctx, send)
236+
if not isinstance(e, ConnectError):
237+
raise
238+
return None
236239

237240
# Streams have their own error handling so move out of the try block.
238241
return await self._handle_stream(
@@ -486,6 +489,8 @@ async def _watch_for_disconnect() -> None:
486489
"more_trailers": False,
487490
}
488491
)
492+
if error and not isinstance(error, ConnectError):
493+
raise error
489494

490495
async def _handle_error(
491496
self, exc: Exception, ctx: RequestContext | None, send: ASGISendCallable

test/test_errors.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import threading
5+
from collections.abc import AsyncIterator
56
from http import HTTPStatus
67
from typing import NoReturn
78

@@ -21,6 +22,7 @@
2122

2223
from connectrpc.code import Code
2324
from connectrpc.errors import ConnectError
25+
from connectrpc.request import RequestContext
2426

2527
from .haberdasher_connect import (
2628
Haberdasher,
@@ -424,3 +426,91 @@ async def make_hat(self, request, ctx) -> NoReturn:
424426
assert exc_info.value.code == Code.DEADLINE_EXCEEDED
425427
assert exc_info.value.message == "Request timed out"
426428
assert recorded_timeout_header == "200"
429+
430+
431+
@pytest.mark.asyncio
432+
async def test_async_unhandled_exception_reraised() -> None:
433+
class RaisingHaberdasher(Haberdasher):
434+
async def make_hat(self, request, ctx) -> NoReturn:
435+
raise TypeError("Something went wrong")
436+
437+
app = HaberdasherASGIApplication(RaisingHaberdasher())
438+
transport = ASGITransport(app)
439+
http_client = Client(transport)
440+
441+
async with HaberdasherClient(
442+
"http://localhost", timeout_ms=200, http_client=http_client
443+
) as client:
444+
with pytest.raises(ConnectError, match="Something went wrong"):
445+
await client.make_hat(request=Size(inches=10))
446+
447+
assert isinstance(transport.app_exception, TypeError)
448+
assert str(transport.app_exception) == "Something went wrong"
449+
450+
451+
@pytest.mark.asyncio
452+
async def test_async_unhandled_exception_reraised_stream() -> None:
453+
class RaisingHaberdasher(Haberdasher):
454+
def make_similar_hats(
455+
self, request: Size, ctx: RequestContext
456+
) -> AsyncIterator[Hat]:
457+
raise TypeError("Something went wrong")
458+
459+
app = HaberdasherASGIApplication(RaisingHaberdasher())
460+
transport = ASGITransport(app)
461+
http_client = Client(transport)
462+
463+
async with HaberdasherClient(
464+
"http://localhost", timeout_ms=200, http_client=http_client
465+
) as client:
466+
with pytest.raises(ConnectError, match="Something went wrong"):
467+
async for _ in client.make_similar_hats(request=Size(inches=10)):
468+
pass
469+
470+
assert isinstance(transport.app_exception, TypeError)
471+
assert str(transport.app_exception) == "Something went wrong"
472+
473+
474+
@pytest.mark.asyncio
475+
async def test_async_connect_exception_not_reraised() -> None:
476+
class RaisingHaberdasher(Haberdasher):
477+
async def make_hat(self, request, ctx) -> NoReturn:
478+
raise ConnectError(Code.INTERNAL, "We're broken")
479+
480+
app = HaberdasherASGIApplication(RaisingHaberdasher())
481+
transport = ASGITransport(app)
482+
http_client = Client(transport)
483+
484+
async with HaberdasherClient(
485+
"http://localhost", timeout_ms=200, http_client=http_client
486+
) as client:
487+
with pytest.raises(ConnectError, match="We're broken"):
488+
await client.make_hat(request=Size(inches=10))
489+
490+
# Workaround https://github.com/curioswitch/pyqwest/pull/148
491+
# TODO: Remove after fix is released
492+
assert getattr(transport, "_app_exception", None) is None
493+
494+
495+
@pytest.mark.asyncio
496+
async def test_async_connect_exception_not_reraised_stream() -> None:
497+
class RaisingHaberdasher(Haberdasher):
498+
def make_similar_hats(
499+
self, request: Size, ctx: RequestContext
500+
) -> AsyncIterator[Hat]:
501+
raise ConnectError(Code.INTERNAL, "We're broken")
502+
503+
app = HaberdasherASGIApplication(RaisingHaberdasher())
504+
transport = ASGITransport(app)
505+
http_client = Client(transport)
506+
507+
async with HaberdasherClient(
508+
"http://localhost", timeout_ms=200, http_client=http_client
509+
) as client:
510+
with pytest.raises(ConnectError, match="We're broken"):
511+
async for _ in client.make_similar_hats(request=Size(inches=10)):
512+
pass
513+
514+
# Workaround https://github.com/curioswitch/pyqwest/pull/148
515+
# TODO: Remove after fix is released
516+
assert getattr(transport, "_app_exception", None) is None

0 commit comments

Comments
 (0)