Skip to content
116 changes: 78 additions & 38 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,21 @@ import (
"net/http"
"strings"
"sync"

"fmt"
"io"
"os"
"reflect"
"runtime"
)

// Handle is a function that can be registered to a route to handle HTTP
// requests. Like http.HandlerFunc, but has a third parameter for the values of
// wildcards (path variables).
type Handle func(http.ResponseWriter, *http.Request, Params)

var DefaultWriter io.Writer = os.Stdout

// Param is a single URL parameter, consisting of a key and a value.
type Param struct {
Key string
Expand Down Expand Up @@ -231,54 +239,61 @@ func (r *Router) putParams(ps *Params) {
}
}

func (r *Router) saveMatchedRoutePath(path string, handle Handle) Handle {
return func(w http.ResponseWriter, req *http.Request, ps Params) {
if ps == nil {
psp := r.getParams()
ps = (*psp)[0:1]
ps[0] = Param{Key: MatchedRoutePathParam, Value: path}
handle(w, req, ps)
r.putParams(psp)
} else {
ps = append(ps, Param{Key: MatchedRoutePathParam, Value: path})
handle(w, req, ps)
// If enabled, adds the matched route path onto the http.Request context
// before invoking the handler.
// The matched route path is only added to handlers of routes that were
// registered when this option was enabled.
func (r *Router) saveMatchedRoutePath(path string, handles ...Handle) HandlersChain {
for i, handle := range handles {
handles[i] = func(w http.ResponseWriter, req *http.Request, ps Params) {
if ps == nil {
psp := r.getParams()
ps = (*psp)[0:1]
ps[0] = Param{Key: MatchedRoutePathParam, Value: path}
handle(w, req, ps)
r.putParams(psp)
} else {
ps = append(ps, Param{Key: MatchedRoutePathParam, Value: path})
handle(w, req, ps)
}
}
}
return handles;
}

// GET is a shortcut for router.Handle(http.MethodGet, path, handle)
func (r *Router) GET(path string, handle Handle) {
r.Handle(http.MethodGet, path, handle)
func (r *Router) GET(path string, handle ...Handle) {
r.Handle(http.MethodGet, path, handle...)
}

// HEAD is a shortcut for router.Handle(http.MethodHead, path, handle)
func (r *Router) HEAD(path string, handle Handle) {
r.Handle(http.MethodHead, path, handle)
func (r *Router) HEAD(path string, handle ...Handle) {
r.Handle(http.MethodHead, path, handle...)
}

// OPTIONS is a shortcut for router.Handle(http.MethodOptions, path, handle)
func (r *Router) OPTIONS(path string, handle Handle) {
r.Handle(http.MethodOptions, path, handle)
func (r *Router) OPTIONS(path string, handle ...Handle) {
r.Handle(http.MethodOptions, path, handle...)
}

// POST is a shortcut for router.Handle(http.MethodPost, path, handle)
func (r *Router) POST(path string, handle Handle) {
r.Handle(http.MethodPost, path, handle)
func (r *Router) POST(path string, handle ...Handle) {
r.Handle(http.MethodPost, path, handle...)
}

// PUT is a shortcut for router.Handle(http.MethodPut, path, handle)
func (r *Router) PUT(path string, handle Handle) {
r.Handle(http.MethodPut, path, handle)
func (r *Router) PUT(path string, handle ...Handle) {
r.Handle(http.MethodPut, path, handle...)
}

// PATCH is a shortcut for router.Handle(http.MethodPatch, path, handle)
func (r *Router) PATCH(path string, handle Handle) {
r.Handle(http.MethodPatch, path, handle)
func (r *Router) PATCH(path string, handle ...Handle) {
r.Handle(http.MethodPatch, path, handle...)
}

// DELETE is a shortcut for router.Handle(http.MethodDelete, path, handle)
func (r *Router) DELETE(path string, handle Handle) {
r.Handle(http.MethodDelete, path, handle)
func (r *Router) DELETE(path string, handle ...Handle) {
r.Handle(http.MethodDelete, path, handle...)
}

// Handle registers a new request handle with the given path and method.
Expand All @@ -289,7 +304,7 @@ func (r *Router) DELETE(path string, handle Handle) {
// This function is intended for bulk loading and to allow the usage of less
// frequently used, non-standardized or custom methods (e.g. for internal
// communication with a proxy).
func (r *Router) Handle(method, path string, handle Handle) {
func (r *Router) Handle(method, path string, handles ...Handle) {
varsCount := uint16(0)

if method == "" {
Expand All @@ -298,19 +313,29 @@ func (r *Router) Handle(method, path string, handle Handle) {
if len(path) < 1 || path[0] != '/' {
panic("path must begin with '/' in path '" + path + "'")
}
if handle == nil {
//HandlersChain is nil or array[0] is nil
if handles == nil || (len(handles)==1 && handles[0]==nil){
panic("handle must not be nil")
}

if r.SaveMatchedRoutePath {
varsCount++
handle = r.saveMatchedRoutePath(path, handle)
handles = r.saveMatchedRoutePath(path, handles...)
}

if r.trees == nil {
r.trees = make(map[string]*node)
}

handleNames := ""
for _, handle := range handles {
if handle != nil{
handleNames = handleNames+","+runtime.FuncForPC(reflect.ValueOf(handle).Pointer()).Name()
}
}
fmt.Fprintf(DefaultWriter, "[router-debug] %-6s %-40s --> [%s] (%d handlers)\n", method, path,handleNames,len(handles) )


root := r.trees[method]
if root == nil {
root = new(node)
Expand All @@ -319,7 +344,7 @@ func (r *Router) Handle(method, path string, handle Handle) {
r.globalAllowed = r.allowed("*", "")
}

root.addRoute(path, handle)
root.addRoute(path, handles)

// Update maxParams
if paramsCount := countParams(path); paramsCount+varsCount > r.maxParams {
Expand Down Expand Up @@ -391,17 +416,17 @@ func (r *Router) recv(w http.ResponseWriter, req *http.Request) {
// If the path was found, it returns the handle function and the path parameter
// values. Otherwise the third return value indicates whether a redirection to
// the same path with an extra / without the trailing slash should be performed.
func (r *Router) Lookup(method, path string) (Handle, Params, bool) {
func (r *Router) Lookup(method, path string) (HandlersChain, Params, bool) {
if root := r.trees[method]; root != nil {
handle, ps, tsr := root.getValue(path, r.getParams)
if handle == nil {
handles, ps, tsr := root.getValue(path, r.getParams)
if handles == nil {
r.putParams(ps)
return nil, nil, tsr
}
if ps == nil {
return handle, nil, tsr
return handles, nil, tsr
}
return handle, *ps, tsr
return handles, *ps, tsr
}
return nil, nil, false
}
Expand Down Expand Up @@ -464,14 +489,29 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

path := req.URL.Path
//输出请求的处理日志
reqLog := fmt.Sprintf("[http-router] %-6s req.uri=%-12s", req.Method,req.RequestURI)

if root := r.trees[req.Method]; root != nil {
if handle, ps, tsr := root.getValue(path, r.getParams); handle != nil {
if handles, ps, tsr := root.getValue(path, r.getParams); handles != nil {
if ps != nil {
handle(w, req, *ps)
r.putParams(ps)
for _,handle := range handles {
//输出请求的处理日志
fmt.Fprintf(DefaultWriter, reqLog+",req.handle=%-40s\n", runtime.FuncForPC(reflect.ValueOf(handle).Pointer()).Name())
handle(w, req, *ps)
r.putParams(ps)
}
} else {
handle(w, req, nil)
for _,handle := range handles {
if handle != nil{
//输出请求的处理日志
fmt.Fprintf(DefaultWriter, reqLog+",req.handle=%-40s\n", runtime.FuncForPC(reflect.ValueOf(handle).Pointer()).Name())
handle(w, req, nil)
}else{
//输出请求的处理日志
fmt.Fprintf(DefaultWriter, reqLog+",req.handle is nil,do nothing.\n")
}
}
}
return
} else if req.Method != http.MethodConnect && path != "/" {
Expand Down
64 changes: 38 additions & 26 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ func TestRouterAPI(t *testing.T) {
httpHandler := handlerStruct{&handler}

router := New()
router.GET("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
get = true
})

router.GET("/GET",
func(w http.ResponseWriter, r *http.Request, _ Params) {
get = true
},func(w http.ResponseWriter, r *http.Request, _ Params) {
get = true
},nil,
)
router.HEAD("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
head = true
})
Expand Down Expand Up @@ -509,56 +514,63 @@ func TestRouterLookup(t *testing.T) {
router := New()

// try empty router first
handle, _, tsr := router.Lookup(http.MethodGet, "/nope")
if handle != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handle)
handles, _, tsr := router.Lookup(http.MethodGet, "/nope")
if handles != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handles)
}
if tsr {
t.Error("Got wrong TSR recommendation!")
}

// insert route and try again
router.GET("/user/:name", wantHandle)
handle, params, _ := router.Lookup(http.MethodGet, "/user/gopher")
if handle == nil {
t.Fatal("Got no handle!")
} else {
handle(nil, nil, nil)
if !routed {
t.Fatal("Routing failed!")
handles, params, _ := router.Lookup(http.MethodGet, "/user/gopher")

for _,handle := range handles {
if handle == nil {
t.Fatal("Got no handle!")
} else {
handle(nil, nil, nil)
if !routed {
t.Fatal("Routing failed!")
}
}
}

if !reflect.DeepEqual(params, wantParams) {
t.Fatalf("Wrong parameter values: want %v, got %v", wantParams, params)
}
routed = false

// route without param
router.GET("/user", wantHandle)
handle, params, _ = router.Lookup(http.MethodGet, "/user")
if handle == nil {
t.Fatal("Got no handle!")
} else {
handle(nil, nil, nil)
if !routed {
t.Fatal("Routing failed!")
handles, params, _ = router.Lookup(http.MethodGet, "/user")
for _,handle := range handles {
if handle == nil {
t.Fatal("Got no handle!")
} else {
handle(nil, nil, nil)
if !routed {
t.Fatal("Routing failed!")
}
}
}

if params != nil {
t.Fatalf("Wrong parameter values: want %v, got %v", nil, params)
}

handle, _, tsr = router.Lookup(http.MethodGet, "/user/gopher/")
if handle != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handle)
handles, _, tsr = router.Lookup(http.MethodGet, "/user/gopher/")
if handles != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handles)
}
if !tsr {
t.Error("Got no TSR recommendation!")
}

handle, _, tsr = router.Lookup(http.MethodGet, "/nope")
if handle != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handle)
handles, _, tsr = router.Lookup(http.MethodGet, "/nope")
if handles != nil {
t.Fatalf("Got handle for unregistered pattern: %v", handles)
}
if tsr {
t.Error("Got wrong TSR recommendation!")
Expand Down
Loading