Moving the method comparison in gorilla mux to the end and not in the chain

This allows mux to compare and match the requests based on the HTTP method as well. If we compare the methods in the middleware chain, then mux will try to redirect the request to the first match it finds.
This commit is contained in:
Arpit Mohan 2019-03-16 17:05:20 +05:30
parent 1133b53437
commit 49b7051cea
4 changed files with 19 additions and 33 deletions

View File

@ -29,27 +29,6 @@ func Logging() Middleware {
}
}
// Method ensures that url can only be requested with a specific method, else returns a 400 Bad Request
func Method(m string) Middleware {
// Create a new Middleware
return func(f http.HandlerFunc) http.HandlerFunc {
// Define the http.HandlerFunc
return func(w http.ResponseWriter, r *http.Request) {
// Do middleware things
if r.Method != m {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
// Call the next middleware/handler in chain
f(w, r)
}
}
}
func Authenticated() Middleware {
// Create a new Middleware

View File

@ -3,7 +3,11 @@ package api
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strconv"
"github.com/gorilla/mux"
"gitlab.com/mobtools/internal-tools-server/models"
"gitlab.com/mobtools/internal-tools-server/services"
@ -64,10 +68,13 @@ func UpdateQuery(w http.ResponseWriter, r *http.Request) {
queryBody := models.Query{}
err := json.NewDecoder(r.Body).Decode(&queryBody)
if err != nil {
log.Printf("Got error when decoding the queryBody. %s", err.Error())
HandleAPIError(w, r, err)
return
}
queryBody.ID, _ = strconv.ParseInt(mux.Vars(r)["id"], 10, 64)
log.Printf("Got query.ID as %d", queryBody.ID)
queryBody, err = services.UpdateQuery(queryBody)
if err != nil {
HandleAPIError(w, r, err)

View File

@ -49,25 +49,25 @@ func intializeServer() *mux.Router {
}
// Auth Endpoints
router.HandleFunc(url.LoginURL, middleware.Chain(api.Login, middleware.Method("GET"), middleware.Logging()))
router.HandleFunc(url.AuthURL, middleware.Chain(api.InitiateAuth, middleware.Method("GET"), middleware.Logging()))
router.HandleFunc(url.AuthCallbackURL, middleware.Chain(api.AuthCallback, middleware.Method("GET"), middleware.Logging()))
router.HandleFunc(url.LogoutURL, middleware.Chain(api.Logout, middleware.Method("GET"), middleware.Logging()))
router.HandleFunc(url.ProfileURL, middleware.Chain(api.GetUserProfile, middleware.Method("GET"), middleware.Logging()))
router.HandleFunc(url.LoginURL, middleware.Chain(api.Login, middleware.Logging())).Methods("GET")
router.HandleFunc(url.AuthURL, middleware.Chain(api.InitiateAuth, middleware.Logging())).Methods("GET")
router.HandleFunc(url.AuthCallbackURL, middleware.Chain(api.AuthCallback, middleware.Logging())).Methods("GET")
router.HandleFunc(url.LogoutURL, middleware.Chain(api.Logout, middleware.Logging())).Methods("GET")
router.HandleFunc(url.ProfileURL, middleware.Chain(api.GetUserProfile, middleware.Logging())).Methods("GET")
// Account CRUD Endpoints
// Component CRUD Endpoints
router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.GetComponents, middleware.Method("GET"), middleware.Authenticated(), middleware.Logging()))
router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.CreateComponents, middleware.Method("POST"), middleware.Authenticated(), middleware.Logging()))
router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.UpdateComponent, middleware.Method("PUT"), middleware.Authenticated(), middleware.Logging()))
router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.GetComponents, middleware.Authenticated(), middleware.Logging())).Methods("GET")
router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.CreateComponents, middleware.Authenticated(), middleware.Logging())).Methods("POST")
router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.UpdateComponent, middleware.Authenticated(), middleware.Logging())).Methods("PUT")
// Page CRUD Endpoints
// Query CRUD Endpoints
router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL+"/execute", middleware.Chain(api.PostQuery, middleware.Method("POST"), middleware.Authenticated(), middleware.Logging()))
router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL, middleware.Chain(api.CreateQuery, middleware.Method("POST"), middleware.Authenticated(), middleware.Logging()))
router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL, middleware.Chain(api.UpdateQuery, middleware.Method("PUT"), middleware.Authenticated(), middleware.Logging()))
router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL+"/execute", middleware.Chain(api.PostQuery, middleware.Authenticated(), middleware.Logging())).Methods("POST")
router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL, middleware.Chain(api.CreateQuery, middleware.Authenticated(), middleware.Logging())).Methods("POST")
router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL+"/{id}", middleware.Chain(api.UpdateQuery, middleware.Authenticated(), middleware.Logging())).Methods("PUT")
return router
}

View File

@ -50,7 +50,7 @@ func InitPostgresDb() (datastore DataStore, err error) {
d.DB.DB().SetMaxOpenConns(d.MaxOpenConnections)
fmt.Println("Successfully connected!")
// listTables()
d.DB.LogMode(true)
return &d, nil
}