diff --git a/router.go b/router.go index 1eab403d..d2f1aa50 100644 --- a/router.go +++ b/router.go @@ -81,6 +81,12 @@ import ( "net/http" "strings" "sync" + + "fmt" + "io" + "os" + "reflect" + "runtime" ) // Handle is a function that can be registered to a route to handle HTTP @@ -88,6 +94,8 @@ import ( // wildcards (path variables). type Handle func(http.ResponseWriter, *http.Request, Params) +var DefaultWriter io.Writer = os.Stdout + // Param is a single URL parameter, consisting of a key and a value. type Param struct { Key string @@ -231,54 +239,61 @@ func (r *Router) putParams(ps *Params) { } } -func (r *Router) saveMatchedRoutePath(path string, handle Handle) Handle { - return func(w http.ResponseWriter, req *http.Request, ps Params) { - if ps == nil { - psp := r.getParams() - ps = (*psp)[0:1] - ps[0] = Param{Key: MatchedRoutePathParam, Value: path} - handle(w, req, ps) - r.putParams(psp) - } else { - ps = append(ps, Param{Key: MatchedRoutePathParam, Value: path}) - handle(w, req, ps) +// If enabled, adds the matched route path onto the http.Request context +// before invoking the handler. +// The matched route path is only added to handlers of routes that were +// registered when this option was enabled. +func (r *Router) saveMatchedRoutePath(path string, handles ...Handle) HandlersChain { + for i, handle := range handles { + handles[i] = func(w http.ResponseWriter, req *http.Request, ps Params) { + if ps == nil { + psp := r.getParams() + ps = (*psp)[0:1] + ps[0] = Param{Key: MatchedRoutePathParam, Value: path} + handle(w, req, ps) + r.putParams(psp) + } else { + ps = append(ps, Param{Key: MatchedRoutePathParam, Value: path}) + handle(w, req, ps) + } } } + return handles; } // GET is a shortcut for router.Handle(http.MethodGet, path, handle) -func (r *Router) GET(path string, handle Handle) { - r.Handle(http.MethodGet, path, handle) +func (r *Router) GET(path string, handle ...Handle) { + r.Handle(http.MethodGet, path, handle...) } // HEAD is a shortcut for router.Handle(http.MethodHead, path, handle) -func (r *Router) HEAD(path string, handle Handle) { - r.Handle(http.MethodHead, path, handle) +func (r *Router) HEAD(path string, handle ...Handle) { + r.Handle(http.MethodHead, path, handle...) } // OPTIONS is a shortcut for router.Handle(http.MethodOptions, path, handle) -func (r *Router) OPTIONS(path string, handle Handle) { - r.Handle(http.MethodOptions, path, handle) +func (r *Router) OPTIONS(path string, handle ...Handle) { + r.Handle(http.MethodOptions, path, handle...) } // POST is a shortcut for router.Handle(http.MethodPost, path, handle) -func (r *Router) POST(path string, handle Handle) { - r.Handle(http.MethodPost, path, handle) +func (r *Router) POST(path string, handle ...Handle) { + r.Handle(http.MethodPost, path, handle...) } // PUT is a shortcut for router.Handle(http.MethodPut, path, handle) -func (r *Router) PUT(path string, handle Handle) { - r.Handle(http.MethodPut, path, handle) +func (r *Router) PUT(path string, handle ...Handle) { + r.Handle(http.MethodPut, path, handle...) } // PATCH is a shortcut for router.Handle(http.MethodPatch, path, handle) -func (r *Router) PATCH(path string, handle Handle) { - r.Handle(http.MethodPatch, path, handle) +func (r *Router) PATCH(path string, handle ...Handle) { + r.Handle(http.MethodPatch, path, handle...) } // DELETE is a shortcut for router.Handle(http.MethodDelete, path, handle) -func (r *Router) DELETE(path string, handle Handle) { - r.Handle(http.MethodDelete, path, handle) +func (r *Router) DELETE(path string, handle ...Handle) { + r.Handle(http.MethodDelete, path, handle...) } // Handle registers a new request handle with the given path and method. @@ -289,7 +304,7 @@ func (r *Router) DELETE(path string, handle Handle) { // This function is intended for bulk loading and to allow the usage of less // frequently used, non-standardized or custom methods (e.g. for internal // communication with a proxy). -func (r *Router) Handle(method, path string, handle Handle) { +func (r *Router) Handle(method, path string, handles ...Handle) { varsCount := uint16(0) if method == "" { @@ -298,19 +313,29 @@ func (r *Router) Handle(method, path string, handle Handle) { if len(path) < 1 || path[0] != '/' { panic("path must begin with '/' in path '" + path + "'") } - if handle == nil { + //HandlersChain is nil or array[0] is nil + if handles == nil || (len(handles)==1 && handles[0]==nil){ panic("handle must not be nil") } if r.SaveMatchedRoutePath { varsCount++ - handle = r.saveMatchedRoutePath(path, handle) + handles = r.saveMatchedRoutePath(path, handles...) } if r.trees == nil { r.trees = make(map[string]*node) } + handleNames := "" + for _, handle := range handles { + if handle != nil{ + handleNames = handleNames+","+runtime.FuncForPC(reflect.ValueOf(handle).Pointer()).Name() + } + } + fmt.Fprintf(DefaultWriter, "[router-debug] %-6s %-40s --> [%s] (%d handlers)\n", method, path,handleNames,len(handles) ) + + root := r.trees[method] if root == nil { root = new(node) @@ -319,7 +344,7 @@ func (r *Router) Handle(method, path string, handle Handle) { r.globalAllowed = r.allowed("*", "") } - root.addRoute(path, handle) + root.addRoute(path, handles) // Update maxParams if paramsCount := countParams(path); paramsCount+varsCount > r.maxParams { @@ -391,17 +416,17 @@ func (r *Router) recv(w http.ResponseWriter, req *http.Request) { // If the path was found, it returns the handle function and the path parameter // values. Otherwise the third return value indicates whether a redirection to // the same path with an extra / without the trailing slash should be performed. -func (r *Router) Lookup(method, path string) (Handle, Params, bool) { +func (r *Router) Lookup(method, path string) (HandlersChain, Params, bool) { if root := r.trees[method]; root != nil { - handle, ps, tsr := root.getValue(path, r.getParams) - if handle == nil { + handles, ps, tsr := root.getValue(path, r.getParams) + if handles == nil { r.putParams(ps) return nil, nil, tsr } if ps == nil { - return handle, nil, tsr + return handles, nil, tsr } - return handle, *ps, tsr + return handles, *ps, tsr } return nil, nil, false } @@ -464,14 +489,29 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } path := req.URL.Path + //输出请求的处理日志 + reqLog := fmt.Sprintf("[http-router] %-6s req.uri=%-12s", req.Method,req.RequestURI) if root := r.trees[req.Method]; root != nil { - if handle, ps, tsr := root.getValue(path, r.getParams); handle != nil { + if handles, ps, tsr := root.getValue(path, r.getParams); handles != nil { if ps != nil { - handle(w, req, *ps) - r.putParams(ps) + for _,handle := range handles { + //输出请求的处理日志 + fmt.Fprintf(DefaultWriter, reqLog+",req.handle=%-40s\n", runtime.FuncForPC(reflect.ValueOf(handle).Pointer()).Name()) + handle(w, req, *ps) + r.putParams(ps) + } } else { - handle(w, req, nil) + for _,handle := range handles { + if handle != nil{ + //输出请求的处理日志 + fmt.Fprintf(DefaultWriter, reqLog+",req.handle=%-40s\n", runtime.FuncForPC(reflect.ValueOf(handle).Pointer()).Name()) + handle(w, req, nil) + }else{ + //输出请求的处理日志 + fmt.Fprintf(DefaultWriter, reqLog+",req.handle is nil,do nothing.\n") + } + } } return } else if req.Method != http.MethodConnect && path != "/" { diff --git a/router_test.go b/router_test.go index ae7d2435..6d22e251 100644 --- a/router_test.go +++ b/router_test.go @@ -81,9 +81,14 @@ func TestRouterAPI(t *testing.T) { httpHandler := handlerStruct{&handler} router := New() - router.GET("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) { - get = true - }) + + router.GET("/GET", + func(w http.ResponseWriter, r *http.Request, _ Params) { + get = true + },func(w http.ResponseWriter, r *http.Request, _ Params) { + get = true + },nil, + ) router.HEAD("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) { head = true }) @@ -509,9 +514,9 @@ func TestRouterLookup(t *testing.T) { router := New() // try empty router first - handle, _, tsr := router.Lookup(http.MethodGet, "/nope") - if handle != nil { - t.Fatalf("Got handle for unregistered pattern: %v", handle) + handles, _, tsr := router.Lookup(http.MethodGet, "/nope") + if handles != nil { + t.Fatalf("Got handle for unregistered pattern: %v", handles) } if tsr { t.Error("Got wrong TSR recommendation!") @@ -519,15 +524,19 @@ func TestRouterLookup(t *testing.T) { // insert route and try again router.GET("/user/:name", wantHandle) - handle, params, _ := router.Lookup(http.MethodGet, "/user/gopher") - if handle == nil { - t.Fatal("Got no handle!") - } else { - handle(nil, nil, nil) - if !routed { - t.Fatal("Routing failed!") + handles, params, _ := router.Lookup(http.MethodGet, "/user/gopher") + + for _,handle := range handles { + if handle == nil { + t.Fatal("Got no handle!") + } else { + handle(nil, nil, nil) + if !routed { + t.Fatal("Routing failed!") + } } } + if !reflect.DeepEqual(params, wantParams) { t.Fatalf("Wrong parameter values: want %v, got %v", wantParams, params) } @@ -535,30 +544,33 @@ func TestRouterLookup(t *testing.T) { // route without param router.GET("/user", wantHandle) - handle, params, _ = router.Lookup(http.MethodGet, "/user") - if handle == nil { - t.Fatal("Got no handle!") - } else { - handle(nil, nil, nil) - if !routed { - t.Fatal("Routing failed!") + handles, params, _ = router.Lookup(http.MethodGet, "/user") + for _,handle := range handles { + if handle == nil { + t.Fatal("Got no handle!") + } else { + handle(nil, nil, nil) + if !routed { + t.Fatal("Routing failed!") + } } } + if params != nil { t.Fatalf("Wrong parameter values: want %v, got %v", nil, params) } - handle, _, tsr = router.Lookup(http.MethodGet, "/user/gopher/") - if handle != nil { - t.Fatalf("Got handle for unregistered pattern: %v", handle) + handles, _, tsr = router.Lookup(http.MethodGet, "/user/gopher/") + if handles != nil { + t.Fatalf("Got handle for unregistered pattern: %v", handles) } if !tsr { t.Error("Got no TSR recommendation!") } - handle, _, tsr = router.Lookup(http.MethodGet, "/nope") - if handle != nil { - t.Fatalf("Got handle for unregistered pattern: %v", handle) + handles, _, tsr = router.Lookup(http.MethodGet, "/nope") + if handles != nil { + t.Fatalf("Got handle for unregistered pattern: %v", handles) } if tsr { t.Error("Got wrong TSR recommendation!") diff --git a/tree.go b/tree.go index 6eb4fe67..8d1bfb9e 100644 --- a/tree.go +++ b/tree.go @@ -71,6 +71,9 @@ const ( catchAll ) +//HandlersChain defines a HandlerFunc array. +type HandlersChain []Handle + type node struct { path string indices string @@ -78,7 +81,7 @@ type node struct { nType nodeType priority uint32 children []*node - handle Handle + handlers HandlersChain } // Increments priority of the given child and reorders if necessary @@ -106,13 +109,13 @@ func (n *node) incrementChildPrio(pos int) int { // addRoute adds a node with the given handle to the path. // Not concurrency-safe! -func (n *node) addRoute(path string, handle Handle) { +func (n *node) addRoute(path string, handlers HandlersChain) { fullPath := path n.priority++ // Empty tree if n.path == "" && n.indices == "" { - n.insertChild(path, fullPath, handle) + n.insertChild(path, fullPath, handlers) n.nType = root return } @@ -132,7 +135,7 @@ walk: nType: static, indices: n.indices, children: n.children, - handle: n.handle, + handlers: n.handlers, priority: n.priority - 1, } @@ -140,7 +143,7 @@ walk: // []byte for proper unicode char conversion, see #65 n.indices = string([]byte{n.path[i]}) n.path = path[:i] - n.handle = nil + n.handlers = nil n.wildChild = false } @@ -201,20 +204,20 @@ walk: n.incrementChildPrio(len(n.indices) - 1) n = child } - n.insertChild(path, fullPath, handle) + n.insertChild(path, fullPath, handlers) return } // Otherwise add handle to current node - if n.handle != nil { + if n.handlers != nil { panic("a handle is already registered for path '" + fullPath + "'") } - n.handle = handle + n.handlers = handlers return } } -func (n *node) insertChild(path, fullPath string, handle Handle) { +func (n *node) insertChild(path, fullPath string, handlers HandlersChain) { for { // Find prefix until first wildcard wildcard, i, valid := findWildcard(path) @@ -270,7 +273,7 @@ func (n *node) insertChild(path, fullPath string, handle Handle) { } // Otherwise we're done. Insert the handle in the new leaf - n.handle = handle + n.handlers = handlers return } @@ -305,7 +308,7 @@ func (n *node) insertChild(path, fullPath string, handle Handle) { child = &node{ path: path[i:], nType: catchAll, - handle: handle, + handlers: handlers, priority: 1, } n.children = []*node{child} @@ -315,7 +318,7 @@ func (n *node) insertChild(path, fullPath string, handle Handle) { // If no wildcard was found, simply insert the path and handle n.path = path - n.handle = handle + n.handlers = handlers } // Returns the handle registered with the given path (key). The values of @@ -323,7 +326,7 @@ func (n *node) insertChild(path, fullPath string, handle Handle) { // If no handle can be found, a TSR (trailing slash redirect) recommendation is // made if a handle exists with an extra (without the) trailing slash for the // given path. -func (n *node) getValue(path string, params func() *Params) (handle Handle, ps *Params, tsr bool) { +func (n *node) getValue(path string, params func() *Params) (handlers HandlersChain, ps *Params, tsr bool) { walk: // Outer loop for walking the tree for { prefix := n.path @@ -346,7 +349,7 @@ walk: // Outer loop for walking the tree // Nothing found. // We can recommend to redirect to the same URL without a // trailing slash if a leaf exists for that path. - tsr = (path == "/" && n.handle != nil) + tsr = (path == "/" && n.handlers != nil) return } @@ -387,13 +390,13 @@ walk: // Outer loop for walking the tree return } - if handle = n.handle; handle != nil { + if handlers = n.handlers; handlers != nil { return } else if len(n.children) == 1 { // No handle found. Check if a handle for this path + a // trailing slash exists for TSR recommendation n = n.children[0] - tsr = (n.path == "/" && n.handle != nil) || (n.path == "" && n.indices == "/") + tsr = (n.path == "/" && n.handlers != nil) || (n.path == "" && n.indices == "/") } return @@ -413,7 +416,7 @@ walk: // Outer loop for walking the tree } } - handle = n.handle + handlers = n.handlers return default: @@ -423,7 +426,7 @@ walk: // Outer loop for walking the tree } else if path == prefix { // We should have reached the node containing the handle. // Check if this node has a handle registered. - if handle = n.handle; handle != nil { + if handlers = n.handlers; handlers != nil { return } @@ -440,8 +443,8 @@ walk: // Outer loop for walking the tree for i, c := range []byte(n.indices) { if c == '/' { n = n.children[i] - tsr = (len(n.path) == 1 && n.handle != nil) || - (n.nType == catchAll && n.children[0].handle != nil) + tsr = (len(n.path) == 1 && n.handlers != nil) || + (n.nType == catchAll && n.children[0].handlers != nil) return } } @@ -452,7 +455,7 @@ walk: // Outer loop for walking the tree // extra trailing slash if a leaf exists for that path tsr = (path == "/") || (len(prefix) == len(path)+1 && prefix[len(path)] == '/' && - path == prefix[:len(prefix)-1] && n.handle != nil) + path == prefix[:len(prefix)-1] && n.handlers != nil) return } } @@ -587,7 +590,7 @@ walk: // Outer loop for walking the tree // Nothing found. We can recommend to redirect to the same URL // without a trailing slash if a leaf exists for that path - if fixTrailingSlash && path == "/" && n.handle != nil { + if fixTrailingSlash && path == "/" && n.handlers != nil { return ciPath } return nil @@ -622,13 +625,13 @@ walk: // Outer loop for walking the tree return nil } - if n.handle != nil { + if n.handlers != nil { return ciPath } else if fixTrailingSlash && len(n.children) == 1 { // No handle found. Check if a handle for this path + a // trailing slash exists n = n.children[0] - if n.path == "/" && n.handle != nil { + if n.path == "/" && n.handlers != nil { return append(ciPath, '/') } } @@ -643,7 +646,7 @@ walk: // Outer loop for walking the tree } else { // We should have reached the node containing the handle. // Check if this node has a handle registered. - if n.handle != nil { + if n.handlers != nil { return ciPath } @@ -653,8 +656,8 @@ walk: // Outer loop for walking the tree for i, c := range []byte(n.indices) { if c == '/' { n = n.children[i] - if (len(n.path) == 1 && n.handle != nil) || - (n.nType == catchAll && n.children[0].handle != nil) { + if (len(n.path) == 1 && n.handlers != nil) || + (n.nType == catchAll && n.children[0].handlers != nil) { return append(ciPath, '/') } return nil @@ -672,7 +675,7 @@ walk: // Outer loop for walking the tree return ciPath } if len(path)+1 == npLen && n.path[len(path)] == '/' && - strings.EqualFold(path[1:], n.path[1:len(path)]) && n.handle != nil { + strings.EqualFold(path[1:], n.path[1:len(path)]) && n.handlers != nil { return append(ciPath, n.path...) } } diff --git a/tree_test.go b/tree_test.go index 2209abcd..e39bb6ff 100644 --- a/tree_test.go +++ b/tree_test.go @@ -25,10 +25,12 @@ import ( // Used as a workaround since we can't compare functions or their addresses var fakeHandlerValue string - -func fakeHandler(val string) Handle { - return func(http.ResponseWriter, *http.Request, Params) { - fakeHandlerValue = val +//HandlersChain +func fakeHandler(val string) HandlersChain { + return []Handle{ + func(http.ResponseWriter, *http.Request, Params) { + fakeHandlerValue = val + }, } } @@ -46,29 +48,30 @@ func getParams() *Params { func checkRequests(t *testing.T, tree *node, requests testRequests) { for _, request := range requests { - handler, psp, _ := tree.getValue(request.path, getParams) - - switch { - case handler == nil: - if !request.nilHandler { - t.Errorf("handle mismatch for route '%s': Expected non-nil handle", request.path) + handlers, psp, _ := tree.getValue(request.path, getParams) + for _,handler := range handlers { + switch { + case handler == nil: + if !request.nilHandler { + t.Errorf("handle mismatch for route '%s': Expected non-nil handle", request.path) + } + case request.nilHandler: + t.Errorf("handle mismatch for route '%s': Expected nil handle", request.path) + default: + handler(nil, nil, nil) + if fakeHandlerValue != request.route { + t.Errorf("handle mismatch for route '%s': Wrong handle (%s != %s)", request.path, fakeHandlerValue, request.route) + } } - case request.nilHandler: - t.Errorf("handle mismatch for route '%s': Expected nil handle", request.path) - default: - handler(nil, nil, nil) - if fakeHandlerValue != request.route { - t.Errorf("handle mismatch for route '%s': Wrong handle (%s != %s)", request.path, fakeHandlerValue, request.route) - } - } - var ps Params - if psp != nil { - ps = *psp - } + var ps Params + if psp != nil { + ps = *psp + } - if !reflect.DeepEqual(ps, request.ps) { - t.Errorf("Params mismatch for route '%s'", request.path) + if !reflect.DeepEqual(ps, request.ps) { + t.Errorf("Params mismatch for route '%s'", request.path) + } } } } @@ -79,7 +82,7 @@ func checkPriorities(t *testing.T, n *node) uint32 { prio += checkPriorities(t, n.children[i]) } - if n.handle != nil { + if n.handlers != nil { prio++ }