Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/cmd/agent-task/capi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const capiHost = "api.githubcopilot.com"
// may be replaced with test doubles in unit tests.
type CapiClient interface {
ListLatestSessionsForViewer(ctx context.Context, limit int) ([]*Session, error)
CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error)
CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string, customAgent string) (*Job, error)
GetJob(ctx context.Context, owner, repo, jobID string) (*Job, error)
GetSession(ctx context.Context, id string) (*Session, error)
GetSessionLogs(ctx context.Context, id string) ([]byte, error)
Expand Down
14 changes: 10 additions & 4 deletions pkg/cmd/agent-task/capi/client_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 21 additions & 5 deletions pkg/cmd/agent-task/capi/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
Expand All @@ -18,6 +19,7 @@ type Job struct {
ID string `json:"job_id,omitempty"`
SessionID string `json:"session_id,omitempty"`
ProblemStatement string `json:"problem_statement,omitempty"`
CustomAgent string `json:"custom_agent,omitempty"`
EventType string `json:"event_type,omitempty"`
ContentFilterMode string `json:"content_filter_mode,omitempty"`
Status string `json:"status,omitempty"`
Expand Down Expand Up @@ -54,7 +56,7 @@ const jobsBasePathV1 = baseCAPIURL + "/agents/swe/v1/jobs"
// CreateJob queues a new job using the v1 Jobs API. It may or may not
// return Pull Request information. If Pull Request information is required
// following up by polling GetJob with the job ID is necessary.
func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error) {
func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*Job, error) {
if owner == "" || repo == "" {
return nil, errors.New("owner and repo are required")
}
Expand All @@ -71,6 +73,7 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen

payload := &Job{
ProblemStatement: problemStatement,
CustomAgent: customAgent,
EventType: defaultEventType,
PullRequest: &prOpts,
}
Expand All @@ -88,8 +91,10 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen
}
defer res.Body.Close()

body, _ := io.ReadAll(res.Body)

var j Job
if err := json.NewDecoder(res.Body).Decode(&j); err != nil {
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&j); err != nil {
if res.StatusCode != http.StatusCreated && res.StatusCode != http.StatusOK { // accept 201 or 200
// This happens when there's an error like unauthorized (401).
statusText := fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(res.StatusCode))
Expand All @@ -99,11 +104,22 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen
}

if res.StatusCode != http.StatusCreated && res.StatusCode != http.StatusOK { // accept 201 or 200
statusText := fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(res.StatusCode))

// If the response has error embeded, we can use that.
// TODO: Does this really ever happen?
if j.ErrorInfo != nil {
return nil, fmt.Errorf("failed to create job: %s", j.ErrorInfo.Message)
return nil, fmt.Errorf("failed to create job: %s: %s", statusText, j.ErrorInfo.Message)
}
statusText := fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(res.StatusCode))
return nil, fmt.Errorf("failed to create job: %s", statusText)

// If the response doesn't have error embedded,
// try to decode the response itself as a jobError.
var errInfo JobError
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&errInfo); err != nil {
return nil, fmt.Errorf("failed to create job: %s", statusText)
}

return nil, fmt.Errorf("failed to create job: %s: %s", statusText, errInfo.Message)
}

return &j, nil
Expand Down
75 changes: 63 additions & 12 deletions pkg/cmd/agent-task/capi/job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,14 @@ func TestGetJob(t *testing.T) {
func TestCreateJobRequiresRepoAndProblemStatement(t *testing.T) {
client := &CAPIClient{}

_, err := client.CreateJob(context.Background(), "", "only-repo", "", "")
_, err := client.CreateJob(context.Background(), "", "only-repo", "", "", "")
assert.EqualError(t, err, "owner and repo are required")
_, err = client.CreateJob(context.Background(), "only-owner", "", "", "")
_, err = client.CreateJob(context.Background(), "only-owner", "", "", "", "")
assert.EqualError(t, err, "owner and repo are required")
_, err = client.CreateJob(context.Background(), "", "", "", "")
_, err = client.CreateJob(context.Background(), "", "", "", "", "")
assert.EqualError(t, err, "owner and repo are required")

_, err = client.CreateJob(context.Background(), "owner", "repo", "", "")
_, err = client.CreateJob(context.Background(), "owner", "repo", "", "", "")
assert.EqualError(t, err, "problem statement is required")
}

Expand All @@ -205,11 +205,12 @@ func TestCreateJob(t *testing.T) {
require.NoError(t, err)

tests := []struct {
name string
baseBranch string
httpStubs func(*testing.T, *httpmock.Registry)
wantErr string
wantOut *Job
name string
baseBranch string
customAgent string
httpStubs func(*testing.T, *httpmock.Registry)
wantErr string
wantOut *Job
}{
{
name: "success",
Expand Down Expand Up @@ -305,6 +306,56 @@ func TestCreateJob(t *testing.T) {
UpdatedAt: sampleDate,
},
},
{
name: "Success with custom agent",
customAgent: "my-custom-agent",
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(httpmock.REST("POST", "agents/swe/v1/jobs/OWNER/REPO"), "api.githubcopilot.com"),
httpmock.RESTPayload(201,
heredoc.Docf(`
{
"job_id": "job123",
"session_id": "sess1",
"problem_statement": "Do the thing",
"custom_agent": "my-custom-agent",
"event_type": "foo",
"content_filter_mode": "foo",
"status": "foo",
"result": "foo",
"actor": {
"id": 1,
"login": "octocat"
},
"created_at": "%[1]s",
"updated_at": "%[1]s"
}
`, sampleDateString),
func(payload map[string]interface{}) {
assert.Equal(t, "Do the thing", payload["problem_statement"])
assert.Equal(t, "gh_cli", payload["event_type"])
assert.Equal(t, "my-custom-agent", payload["custom_agent"])
},
),
)
},
wantOut: &Job{
ID: "job123",
SessionID: "sess1",
ProblemStatement: "Do the thing",
CustomAgent: "my-custom-agent",
EventType: "foo",
ContentFilterMode: "foo",
Status: "foo",
Result: "foo",
Actor: &JobActor{
ID: 1,
Login: "octocat",
},
CreatedAt: sampleDate,
UpdatedAt: sampleDate,
},
},
{
name: "API error, included in response body",
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
Expand All @@ -317,7 +368,7 @@ func TestCreateJob(t *testing.T) {
}`)),
)
},
wantErr: "failed to create job: some error",
wantErr: "failed to create job: 500 Internal Server Error: some error",
},
{
name: "API error",
Expand All @@ -327,7 +378,7 @@ func TestCreateJob(t *testing.T) {
httpmock.StatusStringResponse(500, `{}`),
)
},
wantErr: "failed to create job: 500 Internal Server Error",
wantErr: "failed to create job: 500 Internal Server Error: ",
},
{
name: "invalid JSON response, non-HTTP 200",
Expand Down Expand Up @@ -364,7 +415,7 @@ func TestCreateJob(t *testing.T) {
cfg := config.NewBlankConfig()
capiClient := NewCAPIClient(httpClient, cfg.Authentication())

job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch)
job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch, tt.customAgent)

if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
Expand Down
7 changes: 6 additions & 1 deletion pkg/cmd/agent-task/create/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type CreateOptions struct {
Sleep func(d time.Duration)

ProblemStatement string
CustomAgent string
BackOff backoff.BackOff
BaseBranch string
Prompter prompter.Prompter
Expand Down Expand Up @@ -103,6 +104,9 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co

# Select a different base branch for the PR
$ gh agent-task create "fix errors" --base branch

# Create a task using the custom agent defined in '.github/agents/my-agent.md'
$ gh agent-task create "build me a new app" --custom-agent my-agent
`),
}

Expand All @@ -111,6 +115,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
cmd.Flags().StringVarP(&opts.ProblemStatementFile, "from-file", "F", "", "Read task description from `file` (use \"-\" to read from standard input)")
cmd.Flags().StringVarP(&opts.BaseBranch, "base", "b", "", "Base branch for the pull request (use default branch if not provided)")
cmd.Flags().BoolVar(&opts.Follow, "follow", false, "Follow agent session logs")
cmd.Flags().StringVarP(&opts.CustomAgent, "custom-agent", "a", "", "Use a custom agent for the task. e.g., use 'my-agent' for the 'my-agent.md' agent")

return cmd
}
Expand Down Expand Up @@ -160,7 +165,7 @@ func createRun(opts *CreateOptions) error {
opts.IO.StartProgressIndicatorWithLabel(fmt.Sprintf("Creating agent task in %s/%s...", repo.RepoOwner(), repo.RepoName()))
defer opts.IO.StopProgressIndicator()

job, err := client.CreateJob(ctx, repo.RepoOwner(), repo.RepoName(), opts.ProblemStatement, opts.BaseBranch)
job, err := client.CreateJob(ctx, repo.RepoOwner(), repo.RepoName(), opts.ProblemStatement, opts.BaseBranch, opts.CustomAgent)
if err != nil {
return err
}
Expand Down
Loading