22package github
33
44import (
5+ "bytes"
56 "encoding/json"
67 "fmt"
7- "os/exec"
8- "strconv "
9- "strings "
8+
9+ "github.com/cli/go-gh/v2/pkg/api "
10+ "github.com/cli/go-gh/v2/pkg/repository "
1011)
1112
1213// PR represents a GitHub pull request.
@@ -16,41 +17,138 @@ type PR struct {
1617 Merged bool `json:"merged"`
1718}
1819
19- // CreatePR creates a new pull request and returns the PR number.
20- func CreatePR (base , title , body string ) (int , error ) {
21- args := []string {"pr" , "create" , "--base" , base , "--title" , title , "--body" , body }
22- out , err := exec .Command ("gh" , args ... ).Output ()
20+ // Client wraps the go-gh REST client with repo context.
21+ type Client struct {
22+ rest * api.RESTClient
23+ owner string
24+ repo string
25+ }
26+
27+ // NewClient creates a new GitHub client for the current repository.
28+ func NewClient () (* Client , error ) {
29+ rest , err := api .DefaultRESTClient ()
2330 if err != nil {
24- if exitErr , ok := err .(* exec.ExitError ); ok {
25- return 0 , fmt .Errorf ("gh pr create failed: %s" , string (exitErr .Stderr ))
26- }
27- return 0 , fmt .Errorf ("gh pr create failed: %w" , err )
31+ return nil , fmt .Errorf ("failed to create REST client: %w" , err )
2832 }
2933
30- // Output is the PR URL, extract the number
31- url := strings .TrimSpace (string (out ))
32- parts := strings .Split (url , "/" )
33- if len (parts ) == 0 {
34- return 0 , fmt .Errorf ("unexpected output: %s" , url )
34+ repo , err := repository .Current ()
35+ if err != nil {
36+ return nil , fmt .Errorf ("failed to detect repository: %w" , err )
3537 }
36- return strconv .Atoi (parts [len (parts )- 1 ])
38+
39+ return & Client {
40+ rest : rest ,
41+ owner : repo .Owner ,
42+ repo : repo .Name ,
43+ }, nil
3744}
3845
39- // GetPR fetches PR details by number.
40- func GetPR (number int ) (* PR , error ) {
41- out , err := exec .Command ("gh" , "pr" , "view" , strconv .Itoa (number ), "--json" , "number,state,merged" ).Output ()
46+ // CreatePR creates a new pull request and returns the PR number.
47+ func (c * Client ) CreatePR (head , base , title , body string ) (int , error ) {
48+ path := fmt .Sprintf ("repos/%s/%s/pulls" , c .owner , c .repo )
49+
50+ request := struct {
51+ Head string `json:"head"`
52+ Base string `json:"base"`
53+ Title string `json:"title"`
54+ Body string `json:"body"`
55+ }{
56+ Head : head ,
57+ Base : base ,
58+ Title : title ,
59+ Body : body ,
60+ }
61+
62+ reqBody , err := json .Marshal (request )
4263 if err != nil {
43- return nil , err
64+ return 0 , fmt . Errorf ( "failed to marshal request: %w" , err )
4465 }
4566
67+ var response PR
68+ err = c .rest .Post (path , bytes .NewReader (reqBody ), & response )
69+ if err != nil {
70+ return 0 , fmt .Errorf ("failed to create PR: %w" , err )
71+ }
72+
73+ return response .Number , nil
74+ }
75+
76+ // GetPR fetches PR details by number.
77+ func (c * Client ) GetPR (number int ) (* PR , error ) {
78+ path := fmt .Sprintf ("repos/%s/%s/pulls/%d" , c .owner , c .repo , number )
79+
4680 var pr PR
47- if err := json .Unmarshal (out , & pr ); err != nil {
48- return nil , err
81+ err := c .rest .Get (path , & pr )
82+ if err != nil {
83+ return nil , fmt .Errorf ("failed to get PR #%d: %w" , number , err )
4984 }
85+
5086 return & pr , nil
5187}
5288
5389// UpdatePRBase updates the base branch of a PR.
90+ func (c * Client ) UpdatePRBase (number int , base string ) error {
91+ path := fmt .Sprintf ("repos/%s/%s/pulls/%d" , c .owner , c .repo , number )
92+
93+ request := struct {
94+ Base string `json:"base"`
95+ }{
96+ Base : base ,
97+ }
98+
99+ reqBody , err := json .Marshal (request )
100+ if err != nil {
101+ return fmt .Errorf ("failed to marshal request: %w" , err )
102+ }
103+
104+ err = c .rest .Patch (path , bytes .NewReader (reqBody ), nil )
105+ if err != nil {
106+ return fmt .Errorf ("failed to update PR #%d base: %w" , number , err )
107+ }
108+
109+ return nil
110+ }
111+
112+ // --- Convenience functions for backward compatibility ---
113+
114+ // defaultClient is a lazily-initialized client for convenience functions.
115+ var defaultClient * Client
116+
117+ func getDefaultClient () (* Client , error ) {
118+ if defaultClient == nil {
119+ var err error
120+ defaultClient , err = NewClient ()
121+ if err != nil {
122+ return nil , err
123+ }
124+ }
125+ return defaultClient , nil
126+ }
127+
128+ // CreatePR creates a new pull request using the default client.
129+ // Deprecated: Use NewClient() and call methods directly for better error handling.
130+ func CreatePR (head , base , title , body string ) (int , error ) {
131+ client , err := getDefaultClient ()
132+ if err != nil {
133+ return 0 , err
134+ }
135+ return client .CreatePR (head , base , title , body )
136+ }
137+
138+ // GetPR fetches PR details using the default client.
139+ func GetPR (number int ) (* PR , error ) {
140+ client , err := getDefaultClient ()
141+ if err != nil {
142+ return nil , err
143+ }
144+ return client .GetPR (number )
145+ }
146+
147+ // UpdatePRBase updates the base branch using the default client.
54148func UpdatePRBase (number int , base string ) error {
55- return exec .Command ("gh" , "pr" , "edit" , strconv .Itoa (number ), "--base" , base ).Run ()
149+ client , err := getDefaultClient ()
150+ if err != nil {
151+ return err
152+ }
153+ return client .UpdatePRBase (number , base )
56154}
0 commit comments