Skip to content

Commit cd9e5e5

Browse files
authored
Merge pull request cli#12068 from cli/kw/spike-custom-agents
`gh agent-task create`: support `--custom-agent`/`-a` flag
2 parents a701a37 + 44653c5 commit cd9e5e5

6 files changed

Lines changed: 114 additions & 36 deletions

File tree

pkg/cmd/agent-task/capi/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ const capiHost = "api.githubcopilot.com"
1616
// may be replaced with test doubles in unit tests.
1717
type CapiClient interface {
1818
ListLatestSessionsForViewer(ctx context.Context, limit int) ([]*Session, error)
19-
CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error)
19+
CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string, customAgent string) (*Job, error)
2020
GetJob(ctx context.Context, owner, repo, jobID string) (*Job, error)
2121
GetSession(ctx context.Context, id string) (*Session, error)
2222
GetSessionLogs(ctx context.Context, id string) ([]byte, error)

pkg/cmd/agent-task/capi/client_mock.go

Lines changed: 10 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/cmd/agent-task/capi/job.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9+
"io"
910
"net/http"
1011
"net/url"
1112
"time"
@@ -18,6 +19,7 @@ type Job struct {
1819
ID string `json:"job_id,omitempty"`
1920
SessionID string `json:"session_id,omitempty"`
2021
ProblemStatement string `json:"problem_statement,omitempty"`
22+
CustomAgent string `json:"custom_agent,omitempty"`
2123
EventType string `json:"event_type,omitempty"`
2224
ContentFilterMode string `json:"content_filter_mode,omitempty"`
2325
Status string `json:"status,omitempty"`
@@ -54,7 +56,7 @@ const jobsBasePathV1 = baseCAPIURL + "/agents/swe/v1/jobs"
5456
// CreateJob queues a new job using the v1 Jobs API. It may or may not
5557
// return Pull Request information. If Pull Request information is required
5658
// following up by polling GetJob with the job ID is necessary.
57-
func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error) {
59+
func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*Job, error) {
5860
if owner == "" || repo == "" {
5961
return nil, errors.New("owner and repo are required")
6062
}
@@ -71,6 +73,7 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen
7173

7274
payload := &Job{
7375
ProblemStatement: problemStatement,
76+
CustomAgent: customAgent,
7477
EventType: defaultEventType,
7578
PullRequest: &prOpts,
7679
}
@@ -88,8 +91,10 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen
8891
}
8992
defer res.Body.Close()
9093

94+
body, _ := io.ReadAll(res.Body)
95+
9196
var j Job
92-
if err := json.NewDecoder(res.Body).Decode(&j); err != nil {
97+
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&j); err != nil {
9398
if res.StatusCode != http.StatusCreated && res.StatusCode != http.StatusOK { // accept 201 or 200
9499
// This happens when there's an error like unauthorized (401).
95100
statusText := fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(res.StatusCode))
@@ -99,11 +104,22 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen
99104
}
100105

101106
if res.StatusCode != http.StatusCreated && res.StatusCode != http.StatusOK { // accept 201 or 200
107+
statusText := fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(res.StatusCode))
108+
109+
// If the response has error embeded, we can use that.
110+
// TODO: Does this really ever happen?
102111
if j.ErrorInfo != nil {
103-
return nil, fmt.Errorf("failed to create job: %s", j.ErrorInfo.Message)
112+
return nil, fmt.Errorf("failed to create job: %s: %s", statusText, j.ErrorInfo.Message)
104113
}
105-
statusText := fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(res.StatusCode))
106-
return nil, fmt.Errorf("failed to create job: %s", statusText)
114+
115+
// If the response doesn't have error embedded,
116+
// try to decode the response itself as a jobError.
117+
var errInfo JobError
118+
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&errInfo); err != nil {
119+
return nil, fmt.Errorf("failed to create job: %s", statusText)
120+
}
121+
122+
return nil, fmt.Errorf("failed to create job: %s: %s", statusText, errInfo.Message)
107123
}
108124

109125
return &j, nil

pkg/cmd/agent-task/capi/job_test.go

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,14 @@ func TestGetJob(t *testing.T) {
188188
func TestCreateJobRequiresRepoAndProblemStatement(t *testing.T) {
189189
client := &CAPIClient{}
190190

191-
_, err := client.CreateJob(context.Background(), "", "only-repo", "", "")
191+
_, err := client.CreateJob(context.Background(), "", "only-repo", "", "", "")
192192
assert.EqualError(t, err, "owner and repo are required")
193-
_, err = client.CreateJob(context.Background(), "only-owner", "", "", "")
193+
_, err = client.CreateJob(context.Background(), "only-owner", "", "", "", "")
194194
assert.EqualError(t, err, "owner and repo are required")
195-
_, err = client.CreateJob(context.Background(), "", "", "", "")
195+
_, err = client.CreateJob(context.Background(), "", "", "", "", "")
196196
assert.EqualError(t, err, "owner and repo are required")
197197

198-
_, err = client.CreateJob(context.Background(), "owner", "repo", "", "")
198+
_, err = client.CreateJob(context.Background(), "owner", "repo", "", "", "")
199199
assert.EqualError(t, err, "problem statement is required")
200200
}
201201

@@ -205,11 +205,12 @@ func TestCreateJob(t *testing.T) {
205205
require.NoError(t, err)
206206

207207
tests := []struct {
208-
name string
209-
baseBranch string
210-
httpStubs func(*testing.T, *httpmock.Registry)
211-
wantErr string
212-
wantOut *Job
208+
name string
209+
baseBranch string
210+
customAgent string
211+
httpStubs func(*testing.T, *httpmock.Registry)
212+
wantErr string
213+
wantOut *Job
213214
}{
214215
{
215216
name: "success",
@@ -305,6 +306,56 @@ func TestCreateJob(t *testing.T) {
305306
UpdatedAt: sampleDate,
306307
},
307308
},
309+
{
310+
name: "Success with custom agent",
311+
customAgent: "my-custom-agent",
312+
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
313+
reg.Register(
314+
httpmock.WithHost(httpmock.REST("POST", "agents/swe/v1/jobs/OWNER/REPO"), "api.githubcopilot.com"),
315+
httpmock.RESTPayload(201,
316+
heredoc.Docf(`
317+
{
318+
"job_id": "job123",
319+
"session_id": "sess1",
320+
"problem_statement": "Do the thing",
321+
"custom_agent": "my-custom-agent",
322+
"event_type": "foo",
323+
"content_filter_mode": "foo",
324+
"status": "foo",
325+
"result": "foo",
326+
"actor": {
327+
"id": 1,
328+
"login": "octocat"
329+
},
330+
"created_at": "%[1]s",
331+
"updated_at": "%[1]s"
332+
}
333+
`, sampleDateString),
334+
func(payload map[string]interface{}) {
335+
assert.Equal(t, "Do the thing", payload["problem_statement"])
336+
assert.Equal(t, "gh_cli", payload["event_type"])
337+
assert.Equal(t, "my-custom-agent", payload["custom_agent"])
338+
},
339+
),
340+
)
341+
},
342+
wantOut: &Job{
343+
ID: "job123",
344+
SessionID: "sess1",
345+
ProblemStatement: "Do the thing",
346+
CustomAgent: "my-custom-agent",
347+
EventType: "foo",
348+
ContentFilterMode: "foo",
349+
Status: "foo",
350+
Result: "foo",
351+
Actor: &JobActor{
352+
ID: 1,
353+
Login: "octocat",
354+
},
355+
CreatedAt: sampleDate,
356+
UpdatedAt: sampleDate,
357+
},
358+
},
308359
{
309360
name: "API error, included in response body",
310361
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
@@ -317,7 +368,7 @@ func TestCreateJob(t *testing.T) {
317368
}`)),
318369
)
319370
},
320-
wantErr: "failed to create job: some error",
371+
wantErr: "failed to create job: 500 Internal Server Error: some error",
321372
},
322373
{
323374
name: "API error",
@@ -327,7 +378,7 @@ func TestCreateJob(t *testing.T) {
327378
httpmock.StatusStringResponse(500, `{}`),
328379
)
329380
},
330-
wantErr: "failed to create job: 500 Internal Server Error",
381+
wantErr: "failed to create job: 500 Internal Server Error: ",
331382
},
332383
{
333384
name: "invalid JSON response, non-HTTP 200",
@@ -364,7 +415,7 @@ func TestCreateJob(t *testing.T) {
364415
cfg := config.NewBlankConfig()
365416
capiClient := NewCAPIClient(httpClient, cfg.Authentication())
366417

367-
job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch)
418+
job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch, tt.customAgent)
368419

369420
if tt.wantErr != "" {
370421
require.EqualError(t, err, tt.wantErr)

pkg/cmd/agent-task/create/create.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type CreateOptions struct {
3434
Sleep func(d time.Duration)
3535

3636
ProblemStatement string
37+
CustomAgent string
3738
BackOff backoff.BackOff
3839
BaseBranch string
3940
Prompter prompter.Prompter
@@ -103,6 +104,9 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
103104
104105
# Select a different base branch for the PR
105106
$ gh agent-task create "fix errors" --base branch
107+
108+
# Create a task using the custom agent defined in '.github/agents/my-agent.md'
109+
$ gh agent-task create "build me a new app" --custom-agent my-agent
106110
`),
107111
}
108112

@@ -111,6 +115,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
111115
cmd.Flags().StringVarP(&opts.ProblemStatementFile, "from-file", "F", "", "Read task description from `file` (use \"-\" to read from standard input)")
112116
cmd.Flags().StringVarP(&opts.BaseBranch, "base", "b", "", "Base branch for the pull request (use default branch if not provided)")
113117
cmd.Flags().BoolVar(&opts.Follow, "follow", false, "Follow agent session logs")
118+
cmd.Flags().StringVarP(&opts.CustomAgent, "custom-agent", "a", "", "Use a custom agent for the task. e.g., use 'my-agent' for the 'my-agent.md' agent")
114119

115120
return cmd
116121
}
@@ -160,7 +165,7 @@ func createRun(opts *CreateOptions) error {
160165
opts.IO.StartProgressIndicatorWithLabel(fmt.Sprintf("Creating agent task in %s/%s...", repo.RepoOwner(), repo.RepoName()))
161166
defer opts.IO.StopProgressIndicator()
162167

163-
job, err := client.CreateJob(ctx, repo.RepoOwner(), repo.RepoName(), opts.ProblemStatement, opts.BaseBranch)
168+
job, err := client.CreateJob(ctx, repo.RepoOwner(), repo.RepoName(), opts.ProblemStatement, opts.BaseBranch, opts.CustomAgent)
164169
if err != nil {
165170
return err
166171
}

0 commit comments

Comments
 (0)