Skip to content

Commit a6fea16

Browse files
Patel230Race Tester
andauthored
test(workflow): cover retry cancellation paths (#21)
Co-authored-by: Race Tester <race@test.com>
1 parent d73540e commit a6fea16

9 files changed

Lines changed: 258 additions & 27 deletions

File tree

AGENTS.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ c := New(
6767
```
6868

6969
### Error Hierarchy
70-
`errors.go` defines `APIError` as the base struct with `StatusCode`, `Code`, `Message`, `Details`. Each HTTP status gets a wrapper type (`NotFoundError`, `RateLimitError`, etc.) that embeds `APIError`. All error types implement `Error()` and `Unwrap()` for `errors.Is`/`errors.As` compatibility:
70+
`errors.go` defines `APIError` as the base struct with `StatusCode`, `Code`, `Message`, `Details`. Each HTTP status gets a wrapper type (`NotFoundError`, `RateLimitError`, etc.) that embeds `APIError`. Subtypes inherit `Error()` via embedding (only `RateLimitError` overrides it to append retry info) and implement `Unwrap()` for `errors.Is`/`errors.As` compatibility:
7171
```go
7272
var notFound *NotFoundError
7373
if errors.As(err, &notFound) {
@@ -145,7 +145,7 @@ if !errors.As(err, &notFound) { t.Error("expected NotFoundError") }
145145
- **Do not change**: `APIError.Error()` format string — tests assert exact string output
146146
- **Do not change**: JSON struct tags — they match the daemon's API contract
147147
- **Do not change**: `Unwrap()` implementations — `errors.As` depends on them for type matching
148-
- **Safe to extend**: add new error types by creating a struct embedding `APIError`, adding to `parseAPIError` switch, and implementing `Error()` + `Unwrap()`
148+
- **Safe to extend**: add new error types by creating a struct embedding `APIError` (which promotes `Error()`), adding to `parseAPIError` switch, and implementing `Unwrap()`
149149
- **Safe to extend**: add new client methods by following the `get()`/`post()` delegation pattern
150150
- **When adding streaming endpoints**: follow `ChatStream` pattern — set `Accept: text/event-stream`, check status before wrapping in `newStreamReader`
151151
- **Concurrency**: any new mutable state on `Agent` must be protected by `a.mu`

agent.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ type MemoryConfig struct {
3838

3939
// Agent wraps a Client with declarative configuration, providing a
4040
// simplified interface for conversational AI interactions.
41+
//
42+
// Concurrency: Agent is safe for concurrent use. The session ID is read and
43+
// updated under an internal mutex. Each Chat or ChatStream call captures the
44+
// session ID at the moment the request is built; a stream returned by
45+
// ChatStream continues to use the session ID captured at call time even if a
46+
// concurrent Chat call establishes a new session while the stream is being
47+
// consumed.
4148
type Agent struct {
4249
client *Client
4350
config AgentConfig
@@ -94,7 +101,15 @@ func (a *Agent) Chat(ctx context.Context, message string) (*ChatResponse, error)
94101
// ChatStream sends a message and returns a streaming response reader.
95102
// Note: streaming with tools is not automatically looped; use Chat for
96103
// full tool loop support.
104+
//
105+
// The session ID is captured under the agent's lock when the request is
106+
// built, so the entire stream lifecycle uses that snapshot: a concurrent
107+
// Chat call that mutates the agent's session ID does not affect an
108+
// in-flight stream.
97109
func (a *Agent) ChatStream(ctx context.Context, message string) (*StreamReader, error) {
110+
// Capture the session ID into the request under the lock. req holds the
111+
// captured value by copy, so the stream is immune to later mutations of
112+
// a.sessionID by concurrent Chat calls.
98113
a.mu.Lock()
99114
req := a.buildRequest(message)
100115
a.mu.Unlock()

agent_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"net/http"
77
"net/http/httptest"
8+
"sync"
89
"sync/atomic"
910
"testing"
1011
)
@@ -197,6 +198,67 @@ func TestAgent_ChatStream(t *testing.T) {
197198
}
198199
}
199200

201+
// TestAgent_ConcurrentChatAndStream runs Chat and ChatStream concurrently
202+
// (run with -race) to verify that ChatStream's session ID snapshot is not
203+
// affected by concurrent Chat calls mutating a.sessionID, and that no data
204+
// race exists between request building and session updates.
205+
func TestAgent_ConcurrentChatAndStream(t *testing.T) {
206+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
207+
if r.Header.Get("Accept") == "text/event-stream" {
208+
w.Header().Set("Content-Type", "text/event-stream")
209+
w.WriteHeader(http.StatusOK)
210+
w.Write([]byte("data: chunk\n\n"))
211+
w.Write([]byte("event: done\ndata: {}\n\n"))
212+
return
213+
}
214+
json.NewEncoder(w).Encode(ChatResponse{
215+
SessionID: "race-sess",
216+
Response: "ok",
217+
})
218+
}))
219+
defer srv.Close()
220+
221+
c := New(WithBaseURL(srv.URL))
222+
agent := NewAgent(c, AgentConfig{Model: "test"})
223+
224+
const iterations = 10
225+
var wg sync.WaitGroup
226+
errs := make(chan error, iterations*2)
227+
228+
for i := 0; i < iterations; i++ {
229+
wg.Add(2)
230+
go func() {
231+
defer wg.Done()
232+
if _, err := agent.Chat(context.Background(), "hello"); err != nil {
233+
errs <- err
234+
}
235+
}()
236+
go func() {
237+
defer wg.Done()
238+
stream, err := agent.ChatStream(context.Background(), "stream hello")
239+
if err != nil {
240+
errs <- err
241+
return
242+
}
243+
defer stream.Close()
244+
// Consume the whole stream while Chat calls mutate sessionID.
245+
if _, err := stream.CollectText(context.Background()); err != nil {
246+
errs <- err
247+
}
248+
}()
249+
}
250+
wg.Wait()
251+
close(errs)
252+
253+
for err := range errs {
254+
t.Errorf("concurrent Chat/ChatStream error: %v", err)
255+
}
256+
257+
if got := agent.SessionID(); got != "race-sess" {
258+
t.Errorf("SessionID = %q, want %q", got, "race-sess")
259+
}
260+
}
261+
200262
func TestNewAgent_Defaults(t *testing.T) {
201263
c := New()
202264
agent := NewAgent(c, AgentConfig{})

client.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ func WithAPIKey(key string) ClientOption {
4242
}
4343

4444
// New creates a new hawk SDK client.
45+
//
46+
// Note: the client performs no retries by default. Pass
47+
// WithRetry(DefaultRetryConfig()) for production use to enable automatic
48+
// retries with exponential backoff on transient failures.
4549
func New(opts ...ClientOption) *Client {
4650
c := &Client{
4751
baseURL: defaultBaseURL,
@@ -144,7 +148,10 @@ func (c *Client) DeleteSession(ctx context.Context, id string) error {
144148
}
145149
defer func() { _ = resp.Body.Close() }()
146150

147-
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
151+
// The daemon returns 204 No Content on delete, but older daemon versions
152+
// and intermediary proxies may respond with 200 OK instead. Accepting any
153+
// 2xx keeps this defensive and consistent with post()'s success handling.
154+
if resp.StatusCode/100 != 2 {
148155
return parseAPIError(resp)
149156
}
150157
return nil
@@ -226,7 +233,9 @@ func (c *Client) post(ctx context.Context, path string, body interface{}, out in
226233
}
227234
defer func() { _ = resp.Body.Close() }()
228235

229-
if resp.StatusCode != http.StatusOK {
236+
// Accept any 2xx status: creation endpoints may return 201 Created
237+
// and future endpoints may use other success codes.
238+
if resp.StatusCode/100 != 2 {
230239
return parseAPIError(resp)
231240
}
232241

docs/architecture.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,16 @@ c := hawksdk.New(
5252
health, err := c.Health(ctx)
5353

5454
// 💬 Non-streaming chat
55-
resp, err := c.Chat(ctx, hawksdk.ChatRequest{Message: "list files"})
55+
resp, err := c.Chat(ctx, hawksdk.ChatRequest{Prompt: "list files"})
5656

5757
// 📡 Streaming chat
58-
stream, err := c.ChatStream(ctx, hawksdk.ChatRequest{Message: "explain this code"})
58+
stream, err := c.ChatStream(ctx, hawksdk.ChatRequest{Prompt: "explain this code"})
5959
defer stream.Close()
6060
for { ev, err := stream.Next(); if err != nil { break }; fmt.Print(ev.Data) }
6161

6262
// 📋 Sessions
63-
sessions, _ := c.Sessions(ctx, hawksdk.ListOptions{Limit: 10})
64-
msgs, _ := c.Messages(ctx, sessionID, hawksdk.ListOptions{})
63+
sessions, _ := c.Sessions(ctx, &hawksdk.ListOptions{Limit: 10})
64+
msgs, _ := c.Messages(ctx, sessionID, nil)
6565
_ = c.DeleteSession(ctx, sessionID)
6666

6767
// 📊 Stats
@@ -73,7 +73,7 @@ stats, _ := c.Stats(ctx)
7373
## 🤖 Agent (Higher-Level)
7474

7575
```go
76-
agent := hawksdk.NewAgent(c, hawksdk.AgentConfig{SystemPrompt: "You are a Go expert"})
76+
agent := hawksdk.NewAgent(c, hawksdk.AgentConfig{Model: "claude-sonnet-4-5", MaxRounds: 5})
7777
resp, _ := agent.Chat(ctx, "refactor this function")
7878
// Subsequent calls automatically continue the same session
7979
```

errors.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@ type BadRequestError struct {
3535
APIError
3636
}
3737

38-
// Error implements the error interface.
39-
func (e *BadRequestError) Error() string { return e.APIError.Error() }
40-
4138
// Unwrap allows errors.Is/As to match the underlying APIError.
4239
func (e *BadRequestError) Unwrap() error { return &e.APIError }
4340

@@ -46,9 +43,6 @@ type AuthenticationError struct {
4643
APIError
4744
}
4845

49-
// Error implements the error interface.
50-
func (e *AuthenticationError) Error() string { return e.APIError.Error() }
51-
5246
// Unwrap allows errors.Is/As to match the underlying APIError.
5347
func (e *AuthenticationError) Unwrap() error { return &e.APIError }
5448

@@ -57,9 +51,6 @@ type ForbiddenError struct {
5751
APIError
5852
}
5953

60-
// Error implements the error interface.
61-
func (e *ForbiddenError) Error() string { return e.APIError.Error() }
62-
6354
// Unwrap allows errors.Is/As to match the underlying APIError.
6455
func (e *ForbiddenError) Unwrap() error { return &e.APIError }
6556

@@ -68,9 +59,6 @@ type NotFoundError struct {
6859
APIError
6960
}
7061

71-
// Error implements the error interface.
72-
func (e *NotFoundError) Error() string { return e.APIError.Error() }
73-
7462
// Unwrap allows errors.Is/As to match the underlying APIError.
7563
func (e *NotFoundError) Unwrap() error { return &e.APIError }
7664

@@ -98,9 +86,6 @@ type InternalServerError struct {
9886
APIError
9987
}
10088

101-
// Error implements the error interface.
102-
func (e *InternalServerError) Error() string { return e.APIError.Error() }
103-
10489
// Unwrap allows errors.Is/As to match the underlying APIError.
10590
func (e *InternalServerError) Unwrap() error { return &e.APIError }
10691

@@ -109,9 +94,6 @@ type ServiceUnavailableError struct {
10994
APIError
11095
}
11196

112-
// Error implements the error interface.
113-
func (e *ServiceUnavailableError) Error() string { return e.APIError.Error() }
114-
11597
// Unwrap allows errors.Is/As to match the underlying APIError.
11698
func (e *ServiceUnavailableError) Unwrap() error { return &e.APIError }
11799

sessions_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,29 @@ func TestCreateSession(t *testing.T) {
336336
}
337337
}
338338

339+
// TestCreateSession201 verifies that post() accepts any 2xx status, not
340+
// just 200 OK — creation endpoints commonly return 201 Created.
341+
func TestCreateSession201(t *testing.T) {
342+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
343+
w.Header().Set("Content-Type", "application/json")
344+
w.WriteHeader(http.StatusCreated)
345+
json.NewEncoder(w).Encode(SessionSummary{
346+
ID: "created-sess",
347+
CWD: "/tmp",
348+
})
349+
}))
350+
defer srv.Close()
351+
352+
c := New(WithBaseURL(srv.URL))
353+
resp, err := c.CreateSession(context.Background(), CreateSessionRequest{Name: "n"})
354+
if err != nil {
355+
t.Fatalf("CreateSession() with 201 response error: %v", err)
356+
}
357+
if resp.ID != "created-sess" {
358+
t.Errorf("ID = %q, want %q", resp.ID, "created-sess")
359+
}
360+
}
361+
339362
func TestCreateSessionEmptyBody(t *testing.T) {
340363
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
341364
var req CreateSessionRequest

stream_helpers.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,22 @@ type ToolCallDelta struct {
5252

5353
// CollectText consumes the entire stream and returns the concatenated text content.
5454
// It blocks until the stream ends or the context is cancelled.
55+
//
56+
// When the returned error is non-nil, the returned string may be a partial
57+
// result: it contains all text collected up to the point the error occurred,
58+
// which callers may use or discard as appropriate. The error returned is the
59+
// first error encountered while consuming the stream — if the stream emitted
60+
// an "error" event before a later read failure, the "error" event wins.
5561
func (sr *StreamReader) CollectText(ctx context.Context) (string, error) {
5662
var sb strings.Builder
5763
var firstErr error
5864

5965
for {
6066
select {
6167
case <-ctx.Done():
68+
if firstErr != nil {
69+
return sb.String(), firstErr
70+
}
6271
return sb.String(), ctx.Err()
6372
default:
6473
}
@@ -68,6 +77,11 @@ func (sr *StreamReader) CollectText(ctx context.Context) (string, error) {
6877
return sb.String(), firstErr
6978
}
7079
if err != nil {
80+
// Preserve first-error semantics: an "error" event seen earlier
81+
// takes precedence over a subsequent read failure.
82+
if firstErr != nil {
83+
return sb.String(), firstErr
84+
}
7185
return sb.String(), err
7286
}
7387

0 commit comments

Comments
 (0)