Skip to content

Commit d16bf6b

Browse files
committed
mail: add local rate limits
1 parent 98173ae commit d16bf6b

16 files changed

Lines changed: 1823 additions & 1 deletion

internal/client/api_errors.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
1515

1616
"github.com/larksuite/cli/errs"
17+
"github.com/larksuite/cli/internal/ratelimit"
1718
)
1819

1920
// rawAPIJSONHint guides users when an SDK or response body parse fails. The
@@ -30,6 +31,9 @@ func WrapDoAPIError(err error) error {
3031
if err == nil {
3132
return nil
3233
}
34+
if ratelimit.IsLocalRateLimit(err) {
35+
return err
36+
}
3337

3438
// (1) Pass-through any typed errs.* error.
3539
if _, ok := errs.ProblemOf(err); ok {

internal/client/api_errors_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,21 @@ func TestWrapDoAPIError_Nil(t *testing.T) {
196196
}
197197
}
198198

199+
func TestWrapDoAPIError_PreservesLocalMailRateLimit(t *testing.T) {
200+
original := output.ErrAPI(output.LarkErrRateLimit, "rate limited", map[string]any{
201+
"source": "local_ratelimit",
202+
"retry_after_ms": 100,
203+
})
204+
err := WrapDoAPIError(original)
205+
if err != original {
206+
t.Fatalf("WrapDoAPIError returned %p, want original %p", err, original)
207+
}
208+
var exitErr *output.ExitError
209+
if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" {
210+
t.Fatalf("err = %v, want preserved rate_limit ExitError", err)
211+
}
212+
}
213+
199214
// ─────────────────────────────────────────────────────────────────────────────
200215
// WrapJSONResponseParseError: typed error contract.
201216
//

internal/client/client.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/larksuite/cli/internal/errclass"
2626
"github.com/larksuite/cli/internal/errcompat"
2727
"github.com/larksuite/cli/internal/output"
28+
"github.com/larksuite/cli/internal/ratelimit"
2829
"github.com/larksuite/cli/internal/util"
2930
)
3031

@@ -130,6 +131,10 @@ func (c *APIClient) buildApiReq(request RawApiRequest) (*larkcore.ApiReq, []lark
130131
func (c *APIClient) DoSDKRequest(ctx context.Context, req *larkcore.ApiReq, as core.Identity, extraOpts ...larkcore.RequestOptionFunc) (*larkcore.ApiResp, error) {
131132
var opts []larkcore.RequestOptionFunc
132133

134+
if err := ratelimit.Allow(ctx, c.rateLimitRequest(req)); err != nil {
135+
return nil, err
136+
}
137+
133138
token, err := c.resolveAccessToken(ctx, as)
134139
if err != nil {
135140
// WrapDoAPIError is idempotent on already-classified errors:
@@ -166,6 +171,10 @@ func (c *APIClient) DoSDKRequest(ctx context.Context, req *larkcore.ApiReq, as c
166171
func (c *APIClient) DoStream(ctx context.Context, req *larkcore.ApiReq, as core.Identity, opts ...Option) (*http.Response, error) {
167172
cfg := buildConfig(opts)
168173

174+
if err := ratelimit.Allow(ctx, c.rateLimitRequest(req)); err != nil {
175+
return nil, err
176+
}
177+
169178
// Resolve auth
170179
token, err := c.resolveAccessToken(ctx, as)
171180
if err != nil {
@@ -250,6 +259,21 @@ func (c *APIClient) DoStream(ctx context.Context, req *larkcore.ApiReq, as core.
250259
return resp, nil
251260
}
252261

262+
func (c *APIClient) rateLimitRequest(req *larkcore.ApiReq) ratelimit.Request {
263+
if req == nil {
264+
return ratelimit.Request{}
265+
}
266+
limitReq := ratelimit.Request{
267+
Method: req.HttpMethod,
268+
Path: req.ApiPath,
269+
}
270+
if c != nil && c.Config != nil {
271+
limitReq.Brand = c.Config.Brand
272+
limitReq.AppID = c.Config.AppID
273+
}
274+
return limitReq
275+
}
276+
253277
func streamLogID(header http.Header) string {
254278
logID := strings.TrimSpace(header.Get(larkcore.HttpHeaderKeyLogId))
255279
if logID == "" {
@@ -379,7 +403,7 @@ func (c *APIClient) paginateLoop(ctx context.Context, request RawApiRequest, opt
379403
ExtraOpts: request.ExtraOpts,
380404
})
381405
if err != nil {
382-
if page == 1 {
406+
if page == 1 || ratelimit.IsLocalRateLimit(err) {
383407
return nil, err
384408
}
385409
fmt.Fprintf(c.ErrOut, "[page %d] error, stopping pagination\n", page)

internal/client/client_test.go

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/larksuite/cli/internal/core"
2525
"github.com/larksuite/cli/internal/credential"
2626
"github.com/larksuite/cli/internal/output"
27+
"github.com/larksuite/cli/internal/ratelimit"
2728
)
2829

2930
// roundTripFunc is an adapter to use a function as http.RoundTripper.
@@ -48,6 +49,15 @@ func (s *staticTokenResolver) ResolveToken(_ context.Context, _ credential.Token
4849
return &credential.TokenResult{Token: "test-token"}, nil
4950
}
5051

52+
type countingTokenResolver struct {
53+
count int
54+
}
55+
56+
func (s *countingTokenResolver) ResolveToken(_ context.Context, _ credential.TokenSpec) (*credential.TokenResult, error) {
57+
s.count++
58+
return &credential.TokenResult{Token: "test-token"}, nil
59+
}
60+
5161
// newTestAPIClient creates an APIClient with a mock HTTP transport.
5262
func newTestAPIClient(t *testing.T, rt http.RoundTripper) (*APIClient, *bytes.Buffer) {
5363
t.Helper()
@@ -68,6 +78,14 @@ func newTestAPIClient(t *testing.T, rt http.RoundTripper) (*APIClient, *bytes.Bu
6878
}, errBuf
6979
}
7080

81+
func TestRateLimitRequestNilNoops(t *testing.T) {
82+
ac := &APIClient{Config: &core.CliConfig{AppID: "test-app", Brand: core.BrandFeishu}}
83+
req := ac.rateLimitRequest(nil)
84+
if req.Brand != "" || req.AppID != "" || req.Method != "" || req.Path != "" {
85+
t.Fatalf("rateLimitRequest(nil) = %#v, want empty request", req)
86+
}
87+
}
88+
7189
func TestIsJSONContentType(t *testing.T) {
7290
tests := []struct {
7391
ct string
@@ -234,6 +252,48 @@ func TestPaginateAll_PageLimitStopsPagination(t *testing.T) {
234252
}
235253
}
236254

255+
func TestPaginateAll_ReturnsMailRateLimitAfterFirstPage(t *testing.T) {
256+
now := time.Unix(100, 0)
257+
rule := ratelimit.Rule{
258+
Method: "GET",
259+
CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages",
260+
Window: 2 * time.Second,
261+
Limit: 1,
262+
Scope: ratelimit.ScopeApp,
263+
}
264+
restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now }))
265+
defer restore()
266+
267+
apiCalls := 0
268+
ac, _ := newTestAPIClient(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
269+
apiCalls++
270+
return jsonResponse(map[string]interface{}{
271+
"code": 0, "msg": "ok",
272+
"data": map[string]interface{}{
273+
"items": []interface{}{map[string]interface{}{"id": apiCalls}},
274+
"has_more": true,
275+
"page_token": "next",
276+
},
277+
}), nil
278+
}))
279+
280+
_, err := ac.PaginateAll(context.Background(), RawApiRequest{
281+
Method: "GET",
282+
URL: "/open-apis/mail/v1/user_mailboxes/me/messages",
283+
As: core.AsBot,
284+
}, PaginationOptions{PageLimit: 10, PageDelay: -1})
285+
if err == nil {
286+
t.Fatal("expected local rate limit")
287+
}
288+
var exitErr *output.ExitError
289+
if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" {
290+
t.Fatalf("err = %v, want rate_limit ExitError", err)
291+
}
292+
if apiCalls != 1 {
293+
t.Fatalf("api calls = %d, want 1", apiCalls)
294+
}
295+
}
296+
237297
func TestPaginateAll_NaturalEndClearsPageToken(t *testing.T) {
238298
apiCalls := 0
239299
rt := roundTripFunc(func(req *http.Request) (*http.Response, error) {
@@ -464,6 +524,60 @@ func TestDoStream_TransportFailureSplitsSubtype(t *testing.T) {
464524
}
465525
}
466526

527+
func TestDoStream_MailRateLimitRunsBeforeTokenAndHTTP(t *testing.T) {
528+
now := time.Unix(100, 0)
529+
rule := ratelimit.Rule{
530+
Method: "GET",
531+
CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id",
532+
Window: 2 * time.Second,
533+
Limit: 1,
534+
Scope: ratelimit.ScopeApp,
535+
}
536+
restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now }))
537+
defer restore()
538+
539+
httpCalls := 0
540+
ac := &APIClient{
541+
HTTP: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
542+
httpCalls++
543+
return &http.Response{
544+
StatusCode: 200,
545+
Body: io.NopCloser(strings.NewReader("ok")),
546+
}, nil
547+
})},
548+
Config: &core.CliConfig{AppID: "test-app", AppSecret: "test-secret", Brand: core.BrandFeishu},
549+
}
550+
resolver := &countingTokenResolver{}
551+
ac.Credential = credential.NewCredentialProvider(nil, nil, resolver, nil)
552+
553+
newReq := func() *larkcore.ApiReq {
554+
return &larkcore.ApiReq{
555+
HttpMethod: http.MethodGet,
556+
ApiPath: "/open-apis/mail/v1/user_mailboxes/me/messages/msg_1",
557+
}
558+
}
559+
resp, err := ac.DoStream(context.Background(), newReq(), core.AsBot)
560+
if err != nil {
561+
t.Fatalf("first DoStream err = %v", err)
562+
}
563+
resp.Body.Close()
564+
565+
_, err = ac.DoStream(context.Background(), newReq(), core.AsBot)
566+
if err == nil {
567+
t.Fatal("expected local rate limit")
568+
}
569+
var exitErr *output.ExitError
570+
if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" {
571+
t.Fatalf("err = %v, want rate_limit ExitError", err)
572+
}
573+
if httpCalls != 1 {
574+
t.Fatalf("http calls = %d, want 1", httpCalls)
575+
}
576+
if resolver.count != 1 {
577+
t.Fatalf("token resolutions = %d, want 1", resolver.count)
578+
}
579+
}
580+
467581
// failingTokenResolver always returns TokenUnavailableError, exercising the
468582
// auth/credential failure path through resolveAccessToken.
469583
type failingTokenResolver struct{}
@@ -582,6 +696,81 @@ func TestDoSDKRequest_AuthFailureSurfacesTypedAuthenticationError(t *testing.T)
582696
}
583697
}
584698

699+
func TestDoSDKRequest_MailRateLimitRunsBeforeTokenAndSDK(t *testing.T) {
700+
now := time.Unix(100, 0)
701+
rule := ratelimit.Rule{
702+
Method: "GET",
703+
CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id",
704+
Window: 2 * time.Second,
705+
Limit: 1,
706+
Scope: ratelimit.ScopeApp,
707+
}
708+
restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, func() time.Time { return now }))
709+
defer restore()
710+
711+
httpCalls := 0
712+
ac, _ := newTestAPIClient(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
713+
httpCalls++
714+
return jsonResponse(map[string]interface{}{"code": 0, "msg": "ok"}), nil
715+
}))
716+
resolver := &countingTokenResolver{}
717+
ac.Credential = credential.NewCredentialProvider(nil, nil, resolver, nil)
718+
719+
newReq := func() *larkcore.ApiReq {
720+
return &larkcore.ApiReq{
721+
HttpMethod: http.MethodGet,
722+
ApiPath: "/open-apis/mail/v1/user_mailboxes/me/messages/msg_1",
723+
}
724+
}
725+
if _, err := ac.DoSDKRequest(context.Background(), newReq(), core.AsBot); err != nil {
726+
t.Fatalf("first DoSDKRequest err = %v", err)
727+
}
728+
_, err := ac.DoSDKRequest(context.Background(), newReq(), core.AsBot)
729+
if err == nil {
730+
t.Fatal("expected local rate limit")
731+
}
732+
var exitErr *output.ExitError
733+
if !errors.As(err, &exitErr) || exitErr.Detail == nil || exitErr.Detail.Type != "rate_limit" {
734+
t.Fatalf("err = %v, want rate_limit ExitError", err)
735+
}
736+
if httpCalls != 1 {
737+
t.Fatalf("http calls = %d, want 1", httpCalls)
738+
}
739+
if resolver.count != 1 {
740+
t.Fatalf("token resolutions = %d, want 1", resolver.count)
741+
}
742+
}
743+
744+
func TestDoSDKRequest_NonMailAndUnconfiguredMailStillSend(t *testing.T) {
745+
rule := ratelimit.Rule{
746+
Method: "GET",
747+
CanonicalPath: "/open-apis/mail/v1/user_mailboxes/:user_mailbox_id/messages/:message_id",
748+
Window: time.Second,
749+
Limit: 1,
750+
Scope: ratelimit.ScopeApp,
751+
}
752+
restore := ratelimit.SetDefaultLimiterForTest(ratelimit.NewLimiterForDir(t.TempDir(), []ratelimit.Rule{rule}, time.Now))
753+
defer restore()
754+
755+
httpCalls := 0
756+
ac, _ := newTestAPIClient(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
757+
httpCalls++
758+
return jsonResponse(map[string]interface{}{"code": 0, "msg": "ok"}), nil
759+
}))
760+
761+
for _, path := range []string{
762+
"/open-apis/contact/v3/users/u1",
763+
"/open-apis/mail/v1/user_mailboxes/me/settings",
764+
} {
765+
if _, err := ac.DoSDKRequest(context.Background(), &larkcore.ApiReq{HttpMethod: http.MethodGet, ApiPath: path}, core.AsBot); err != nil {
766+
t.Fatalf("DoSDKRequest(%s) err = %v", path, err)
767+
}
768+
}
769+
if httpCalls != 2 {
770+
t.Fatalf("http calls = %d, want 2", httpCalls)
771+
}
772+
}
773+
585774
// TestDoSDKRequest_TransportFailureWrapsAsNetwork pins that genuinely untyped
586775
// SDK transport errors get the typed network classification via WrapDoAPIError.
587776
// io.ErrUnexpectedEOF from a RoundTripper surfaces through net/http as a

0 commit comments

Comments
 (0)