diff --git a/core/restc/client.go b/core/restc/client.go new file mode 100644 index 000000000..113241438 --- /dev/null +++ b/core/restc/client.go @@ -0,0 +1,76 @@ +package restc + +import ( + "net/http" + "sync" + "time" +) + +type Client interface { + Verb(verb string) *Request + SetHeader(headers http.Header) +} + +type Opt func(client *client) error + +type client struct { + lock *sync.RWMutex + addr string + + retryTimes int + retryDelay time.Duration + + headers http.Header + + // Set specific behavior of the client. If not set http.DefaultClient will be used. + client *http.Client + + // middleware + beforeRequest []RequestMiddleware +} + +func (c *client) SetHeader(headers http.Header) { + c.headers = headers +} + +type RequestMiddleware func(Client, *Request) error + +func (c *client) requestMiddlewares() []RequestMiddleware { + c.lock.RLock() + defer c.lock.RUnlock() + return c.beforeRequest +} + +func (c *client) executeRequestMiddlewares(req *Request) (err error) { + for _, f := range c.requestMiddlewares() { + if err = f(c, req); err != nil { + return err + } + } + return nil +} + +func (c *client) Verb(verb string) *Request { + return NewRequest(c).Verb(verb) +} + +func NewClient(addr string, ops ...Opt) (Client, error) { + c := &client{ + lock: &sync.RWMutex{}, + addr: addr, + } + + for _, op := range ops { + if err := op(c); err != nil { + return nil, err + } + } + + if c.client == nil { + c.client = &http.Client{} + } + if c.headers == nil { + c.headers = make(http.Header) + } + return c, nil +} diff --git a/core/restc/option.go b/core/restc/option.go new file mode 100644 index 000000000..56f48c4e4 --- /dev/null +++ b/core/restc/option.go @@ -0,0 +1,43 @@ +package restc + +import ( + "net/http" + "time" +) + +func WithHeaders(headers http.Header) Opt { + return func(c *client) error { + c.headers = headers + return nil + } +} + +func WithRetryTimes(times int) Opt { + return func(c *client) error { + c.retryTimes = times + return nil + } +} + +func WithRetryDelay(time time.Duration) Opt { + return func(c *client) error { + c.retryDelay = time + return nil + } +} + +func WithClient(c *http.Client) Opt { + return func(client *client) error { + client.client = c + return nil + } +} + +func WithRequestMiddleware(middleware RequestMiddleware) Opt { + return func(c *client) error { + c.lock.Lock() + defer c.lock.Unlock() + c.beforeRequest = append(c.beforeRequest, middleware) + return nil + } +} diff --git a/core/restc/request.go b/core/restc/request.go new file mode 100644 index 000000000..c390dcc7d --- /dev/null +++ b/core/restc/request.go @@ -0,0 +1,388 @@ +package restc + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "reflect" + "strings" + "time" + + "github.com/bitly/go-simplejson" + "github.com/pkg/errors" + "github.com/spf13/cast" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +var ( + DefaultCodeField = "code" + DefaultDataField = "data" + DefaultMessageField = "msg" +) + +// Request allows for building up a request to a server in a chained fashion. +// Any errors are stored until the end of your call, so you only have to +// check once. +type Request struct { + c *client + + verb string + headers http.Header + path string + params string + body io.Reader + + err error +} + +func NewRequest(c *client) *Request { + r := &Request{ + c: c, + headers: c.headers.Clone(), + } + return r +} + +func (r *Request) Verb(verb string) *Request { + r.verb = verb + return r +} + +func (r *Request) GetVerb() string { + return r.verb +} + +func (r *Request) GetBody() io.Reader { + return r.body +} + +func (r *Request) GetParams() string { + return r.params +} + +func (r *Request) GetPath() string { + return r.path +} + +func (r *Request) AddHeader(key, value string) { + r.c.lock.Lock() + defer r.c.lock.Unlock() + r.headers.Set(key, value) +} + +type PathParam struct { + Name string + Value any +} + +// Path set path +func (r *Request) Path(path string, args ...PathParam) *Request { + for _, v := range args { + val := reflect.ValueOf(v.Value) + kind := val.Kind() + if kind == reflect.Slice || kind == reflect.Array { + js, err := json.Marshal(v.Value) + if err != nil { + panic(err) + } + path = strings.ReplaceAll(path, "{"+v.Name+"}", cast.ToString(js[1:len(js)-1])) + path = strings.ReplaceAll(path, ":"+v.Name, cast.ToString(js[1:len(js)-1])) + } else { + path = strings.ReplaceAll(path, "{"+v.Name+"}", cast.ToString(v.Value)) + path = strings.ReplaceAll(path, ":"+v.Name, cast.ToString(v.Value)) + } + } + r.path = path + return r +} + +type QueryParam struct { + Name string + Value any +} + +func (r *Request) Params(args ...QueryParam) *Request { + if len(args) == 0 { + return r + } + + var queryParams strings.Builder + queryParams.WriteString("?") + for i, v := range args { + val := reflect.ValueOf(v.Value) + kind := val.Kind() + if kind == reflect.Slice || kind == reflect.Array { + length := val.Len() + for j := 0; j < length; j++ { + value := val.Index(j).Interface() + if cast.ToString(value) == "" { + continue + } + va := url.QueryEscape(cast.ToString(value)) + if i == len(args)-1 && j == length-1 { + queryParams.WriteString(fmt.Sprintf("%s=%s", v.Name, va)) + } else { + queryParams.WriteString(fmt.Sprintf("%s=%s&", v.Name, va)) + } + } + } else { + if cast.ToString(v.Value) == "" { + continue + } + va := url.QueryEscape(cast.ToString(v.Value)) + if i == len(args)-1 { + queryParams.WriteString(fmt.Sprintf("%s=%s", v.Name, va)) + } else { + queryParams.WriteString(fmt.Sprintf("%s=%s&", v.Name, va)) + } + } + } + r.params = queryParams.String() + return r +} + +// Body makes the request use obj as the body. Optional. +func (r *Request) Body(obj any) *Request { + if r.err != nil { + return r + } + + switch t := obj.(type) { + case io.Reader: + r.body = t + case io.ReadCloser: + r.body = t + case string: + r.body = bytes.NewReader([]byte(t)) + case []byte: + r.body = bytes.NewReader(t) + default: + data, err := json.Marshal(obj) + if err != nil { + r.err = err + return r + } + r.body = bytes.NewReader(data) + } + return r +} + +// Result contains the result of calling Request.Do(). +type Result struct { + body []byte + err error + statusCode int + status string +} + +// Do format and executes the request. Returns a Result object for easy response +func (r *Request) Do(ctx context.Context) Result { + if err := r.c.executeRequestMiddlewares(r); err != nil { + return Result{err: err} + } + + request, err := http.NewRequestWithContext(ctx, r.verb, r.c.addr+r.path+r.params, r.body) + if err != nil { + return Result{err: err} + } + + request.Header = r.headers + + var rawResp *http.Response + for k := 0; k <= r.c.retryTimes; k++ { + rawResp, err = r.doRequest(r.c.client, request) + if err == nil || k == r.c.retryTimes { + break + } + + time.Sleep(r.c.retryDelay) + continue + } + + if err != nil { + return Result{err: err} + } + + if rawResp == nil { + return Result{err: errors.New("http response is nil")} + } + + data, err := io.ReadAll(rawResp.Body) + if err != nil { + return Result{err: err} + } + defer rawResp.Body.Close() + + return Result{ + body: data, + err: err, + statusCode: rawResp.StatusCode, + status: rawResp.Status, + } +} + +func (r *Request) doRequest(client *http.Client, request *http.Request) (*http.Response, error) { + res, err := client.Do(request) + if err != nil { + return nil, err + } + if res == nil { + return nil, errors.New("response is nil") + } + return res, nil +} + +type WrapCodeMsgMapping struct { + Code string + Data string + Msg string +} + +type IntoOptions struct { + WrapCodeMsg bool + WrapCodeMsgMapping WrapCodeMsgMapping +} + +// Into stores the result into obj, if possible. If obj is nil it is ignored. +func (r Result) Into(obj any, options *IntoOptions) error { + if reflect.TypeOf(obj).Kind() != reflect.Ptr { + return errors.New("object is not a ptr") + } + + if r.err != nil { + return r.err + } + + if options != nil { + if options.WrapCodeMsg && options.WrapCodeMsgMapping.Code == "" { + options.WrapCodeMsgMapping.Code = DefaultCodeField + } + if options.WrapCodeMsg && options.WrapCodeMsgMapping.Data == "" { + options.WrapCodeMsgMapping.Data = DefaultDataField + } + if options.WrapCodeMsg && options.WrapCodeMsgMapping.Msg == "" { + options.WrapCodeMsgMapping.Msg = DefaultMessageField + } + } + + if r.StatusCode() != http.StatusOK { + s := string(r.body) + + if len(s) == 0 { + return fmt.Errorf("empty response body, status code: %d", r.StatusCode()) + } + + if options != nil && options.WrapCodeMsg { + j, err := simplejson.NewJson(r.body) + if err != nil { + return fmt.Errorf("marsher json error: %v, response body: %v", err, r.body) + } + message, _ := j.Get(options.WrapCodeMsgMapping.Msg).String() + return errors.New(message) + } + return errors.New(s) + } + + j, err := simplejson.NewJson(r.body) + if err != nil { + return err + } + + var marshalJSON []byte + if options != nil && options.WrapCodeMsg { + code, err := j.Get(options.WrapCodeMsgMapping.Code).Int() + if err != nil { + return err + } + if code != http.StatusOK { + message, _ := j.Get(options.WrapCodeMsgMapping.Msg).String() + return errors.New(message) + } + data := j.Get(options.WrapCodeMsgMapping.Data) + marshalJSON, err = data.MarshalJSON() + if err != nil { + return err + } + } else { + marshalJSON, err = j.MarshalJSON() + if err != nil { + return err + } + } + + switch v := obj.(type) { + case proto.Message: + parser := protojson.UnmarshalOptions{ + DiscardUnknown: true, + } + err = parser.Unmarshal(marshalJSON, v) + default: + err = json.Unmarshal(marshalJSON, &obj) + } + + if err != nil { + return err + } + + return nil +} + +// Stream return io.ReadCloser +func (r *Request) Stream(ctx context.Context) (io.ReadCloser, error) { + request, err := http.NewRequestWithContext(ctx, r.verb, r.c.addr+r.path+r.params, r.body) + if err != nil { + return nil, err + } + + request.Header = r.headers + + var rawResp *http.Response + for k := 0; k <= r.c.retryTimes; k++ { + rawResp, err = r.doRequest(r.c.client, request) + if err == nil || k == r.c.retryTimes { + break + } + + time.Sleep(r.c.retryDelay) + continue + } + + if err != nil { + return nil, err + } + + if rawResp == nil { + return nil, errors.New("empty resp") + } + + if rawResp.StatusCode != 200 { + return nil, errors.Errorf("unhealthy status code: [%d], status message: [%s]", rawResp.StatusCode, rawResp.Status) + } + + return rawResp.Body, nil +} + +func (r Result) RawResponse() ([]byte, error) { + return r.body, r.err +} + +// Error returns the error executing the request, nil if no error occurred. +func (r Result) Error() error { + return r.err +} + +// StatusCode returns the HTTP status code of the request. (Only valid if no +// error was returned.) +func (r Result) StatusCode() int { + return r.statusCode +} + +// Status returns the status executing the request +func (r Result) Status() string { + return r.status +}