From 49b7051cead05e56ae671ece3849fec96db40743 Mon Sep 17 00:00:00 2001 From: Arpit Mohan Date: Sat, 16 Mar 2019 17:05:20 +0530 Subject: [PATCH] Moving the method comparison in gorilla mux to the end and not in the chain This allows mux to compare and match the requests based on the HTTP method as well. If we compare the methods in the middleware chain, then mux will try to redirect the request to the first match it finds. --- app/server/api/middleware/middleware.go | 21 --------------------- app/server/api/query.go | 7 +++++++ app/server/server.go | 22 +++++++++++----------- app/server/storage/postgres.go | 2 +- 4 files changed, 19 insertions(+), 33 deletions(-) diff --git a/app/server/api/middleware/middleware.go b/app/server/api/middleware/middleware.go index 5507f88915..4f3e465d54 100644 --- a/app/server/api/middleware/middleware.go +++ b/app/server/api/middleware/middleware.go @@ -29,27 +29,6 @@ func Logging() Middleware { } } -// Method ensures that url can only be requested with a specific method, else returns a 400 Bad Request -func Method(m string) Middleware { - - // Create a new Middleware - return func(f http.HandlerFunc) http.HandlerFunc { - - // Define the http.HandlerFunc - return func(w http.ResponseWriter, r *http.Request) { - - // Do middleware things - if r.Method != m { - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return - } - - // Call the next middleware/handler in chain - f(w, r) - } - } -} - func Authenticated() Middleware { // Create a new Middleware diff --git a/app/server/api/query.go b/app/server/api/query.go index 4a605b24c5..bee620ff0f 100644 --- a/app/server/api/query.go +++ b/app/server/api/query.go @@ -3,7 +3,11 @@ package api import ( "encoding/json" "fmt" + "log" "net/http" + "strconv" + + "github.com/gorilla/mux" "gitlab.com/mobtools/internal-tools-server/models" "gitlab.com/mobtools/internal-tools-server/services" @@ -64,10 +68,13 @@ func UpdateQuery(w http.ResponseWriter, r *http.Request) { queryBody := models.Query{} err := json.NewDecoder(r.Body).Decode(&queryBody) if err != nil { + log.Printf("Got error when decoding the queryBody. %s", err.Error()) HandleAPIError(w, r, err) return } + queryBody.ID, _ = strconv.ParseInt(mux.Vars(r)["id"], 10, 64) + log.Printf("Got query.ID as %d", queryBody.ID) queryBody, err = services.UpdateQuery(queryBody) if err != nil { HandleAPIError(w, r, err) diff --git a/app/server/server.go b/app/server/server.go index 612d6d2f93..ebd1414a44 100644 --- a/app/server/server.go +++ b/app/server/server.go @@ -49,25 +49,25 @@ func intializeServer() *mux.Router { } // Auth Endpoints - router.HandleFunc(url.LoginURL, middleware.Chain(api.Login, middleware.Method("GET"), middleware.Logging())) - router.HandleFunc(url.AuthURL, middleware.Chain(api.InitiateAuth, middleware.Method("GET"), middleware.Logging())) - router.HandleFunc(url.AuthCallbackURL, middleware.Chain(api.AuthCallback, middleware.Method("GET"), middleware.Logging())) - router.HandleFunc(url.LogoutURL, middleware.Chain(api.Logout, middleware.Method("GET"), middleware.Logging())) - router.HandleFunc(url.ProfileURL, middleware.Chain(api.GetUserProfile, middleware.Method("GET"), middleware.Logging())) + router.HandleFunc(url.LoginURL, middleware.Chain(api.Login, middleware.Logging())).Methods("GET") + router.HandleFunc(url.AuthURL, middleware.Chain(api.InitiateAuth, middleware.Logging())).Methods("GET") + router.HandleFunc(url.AuthCallbackURL, middleware.Chain(api.AuthCallback, middleware.Logging())).Methods("GET") + router.HandleFunc(url.LogoutURL, middleware.Chain(api.Logout, middleware.Logging())).Methods("GET") + router.HandleFunc(url.ProfileURL, middleware.Chain(api.GetUserProfile, middleware.Logging())).Methods("GET") // Account CRUD Endpoints // Component CRUD Endpoints - router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.GetComponents, middleware.Method("GET"), middleware.Authenticated(), middleware.Logging())) - router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.CreateComponents, middleware.Method("POST"), middleware.Authenticated(), middleware.Logging())) - router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.UpdateComponent, middleware.Method("PUT"), middleware.Authenticated(), middleware.Logging())) + router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.GetComponents, middleware.Authenticated(), middleware.Logging())).Methods("GET") + router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.CreateComponents, middleware.Authenticated(), middleware.Logging())).Methods("POST") + router.HandleFunc(baseAPIURL+apiVersion+url.ComponentURL, middleware.Chain(api.UpdateComponent, middleware.Authenticated(), middleware.Logging())).Methods("PUT") // Page CRUD Endpoints // Query CRUD Endpoints - router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL+"/execute", middleware.Chain(api.PostQuery, middleware.Method("POST"), middleware.Authenticated(), middleware.Logging())) - router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL, middleware.Chain(api.CreateQuery, middleware.Method("POST"), middleware.Authenticated(), middleware.Logging())) - router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL, middleware.Chain(api.UpdateQuery, middleware.Method("PUT"), middleware.Authenticated(), middleware.Logging())) + router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL+"/execute", middleware.Chain(api.PostQuery, middleware.Authenticated(), middleware.Logging())).Methods("POST") + router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL, middleware.Chain(api.CreateQuery, middleware.Authenticated(), middleware.Logging())).Methods("POST") + router.HandleFunc(baseAPIURL+apiVersion+url.QueryURL+"/{id}", middleware.Chain(api.UpdateQuery, middleware.Authenticated(), middleware.Logging())).Methods("PUT") return router } diff --git a/app/server/storage/postgres.go b/app/server/storage/postgres.go index 1588c6ee25..8429324861 100644 --- a/app/server/storage/postgres.go +++ b/app/server/storage/postgres.go @@ -50,7 +50,7 @@ func InitPostgresDb() (datastore DataStore, err error) { d.DB.DB().SetMaxOpenConns(d.MaxOpenConnections) fmt.Println("Successfully connected!") - // listTables() + d.DB.LogMode(true) return &d, nil }