diff --git a/pkg/httputils/httputils.go b/pkg/httputils/httputils.go index 4db199f0..70a1c122 100644 --- a/pkg/httputils/httputils.go +++ b/pkg/httputils/httputils.go @@ -21,9 +21,12 @@ var ( _ swagger.Provider = (&app{}).Swagger ) +// Pinger describes a function to check liveness of app +type Pinger = func() error + // App of package type App interface { - Health(http.Handler) App + Health(Pinger) App Middleware(model.Middleware) App ListenAndServe(http.Handler) (*http.Server, <-chan error) ListenServeWait(http.Handler) @@ -49,7 +52,7 @@ type app struct { shutdown bool middlewares []model.Middleware - health http.Handler + pingers []Pinger } // Flags adds flags for configuring package @@ -79,8 +82,8 @@ func New(config Config) App { okStatus: *config.okStatus, graceDuration: graceDuration, - health: HealthHandler(*config.okStatus), middlewares: make([]model.Middleware, 0), + pingers: make([]Pinger, 0), } } @@ -102,15 +105,29 @@ func versionHandler() http.Handler { }) } -// HealthHandler for dealing with state of app -func HealthHandler(okStatus int) http.Handler { +// Health set health http handler +func (a *app) healthHandler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { w.WriteHeader(http.StatusMethodNotAllowed) return } - w.WriteHeader(okStatus) + if a.shutdown { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + + for _, pinger := range a.pingers { + if err := pinger(); err != nil { + logger.Error("unable to ping: %s", err) + + w.WriteHeader(http.StatusServiceUnavailable) + return + } + } + + w.WriteHeader(a.okStatus) }) } @@ -125,16 +142,16 @@ func ChainMiddlewares(handler http.Handler, middlewares ...model.Middleware) htt return result } -// Health set health http handler -func (a *app) Health(health http.Handler) App { - a.health = health +// Middleware add given middleware to list +func (a *app) Middleware(middleware model.Middleware) App { + a.middlewares = append(a.middlewares, middleware) return a } -// Middleware add given middleware to list -func (a *app) Middleware(middleware model.Middleware) App { - a.middlewares = append(a.middlewares, middleware) +// Health add given pinger to list +func (a *app) Health(pinger Pinger) App { + a.pingers = append(a.pingers, pinger) return a } @@ -173,6 +190,7 @@ func (a *app) Swagger() (swagger.Configuration, error) { // ListenAndServe starts server func (a *app) ListenAndServe(handler http.Handler) (*http.Server, <-chan error) { versionHandler := versionHandler() + healthHandler := a.healthHandler() defaultHandler := ChainMiddlewares(handler, a.middlewares...) httpServer := &http.Server{ @@ -180,11 +198,7 @@ func (a *app) ListenAndServe(handler http.Handler) (*http.Server, <-chan error) Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/health": - if a.shutdown { - w.WriteHeader(http.StatusServiceUnavailable) - return - } - a.health.ServeHTTP(w, r) + healthHandler.ServeHTTP(w, r) case "/version": versionHandler.ServeHTTP(w, r) diff --git a/pkg/httputils/httputils_test.go b/pkg/httputils/httputils_test.go index dd67f3a0..6029e0c2 100644 --- a/pkg/httputils/httputils_test.go +++ b/pkg/httputils/httputils_test.go @@ -89,43 +89,6 @@ func TestVersionHandler(t *testing.T) { } } -func TestHealthHandler(t *testing.T) { - var cases = []struct { - intention string - request *http.Request - want string - wantStatus int - }{ - { - "simple", - httptest.NewRequest(http.MethodGet, "/", nil), - "", - http.StatusNoContent, - }, - { - "invalid method", - httptest.NewRequest(http.MethodOptions, "/", nil), - "", - http.StatusMethodNotAllowed, - }, - } - - for _, testCase := range cases { - t.Run(testCase.intention, func(t *testing.T) { - writer := httptest.NewRecorder() - HealthHandler(http.StatusNoContent).ServeHTTP(writer, testCase.request) - - if result := writer.Code; result != testCase.wantStatus { - t.Errorf("HealthHandler = %d, want %d", result, testCase.wantStatus) - } - - if result, _ := request.ReadBodyResponse(writer.Result()); string(result) != testCase.want { - t.Errorf("HealthHandler = `%s`, want `%s`", string(result), testCase.want) - } - }) - } -} - func TestChainMiddlewares(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("handler"))