package restart

import (
	"log"
	"net/http"
	"os"
	"os/signal"
	"sync/atomic"
	"syscall"
	"time"
)

var (
	// WaitForHealthCheck is how long the shutdown goroutine waits, after a
	// shutdown signal has switched restart probes to 503, before calling the
	// function registered with OnShutdown. It gives primary/backup proxies
	// time to notice the failing health check so new connections stop
	// arriving. It is read without synchronization when the signal arrives,
	// so set it before calling OnShutdown.
	WaitForHealthCheck = 5 * time.Second

	// shutdownFn is the callback registered via OnShutdown; nil until then.
	shutdownFn func()
	// shuttingDown is 0 until a shutdown signal is received, then 1.
	// It is only read and written through sync/atomic.
	shuttingDown int32

	// sigs receives the signals subscribed in init. Package signal never
	// blocks when sending, so the buffer of 1 holds one pending signal until
	// the listener goroutine reads it.
	sigs = make(chan os.Signal, 1)
)

// init subscribes sigs to SIGINT and SIGHUP as soon as the package is
// loaded, so a signal that arrives before OnShutdown starts listening is
// still held in the channel's buffer.
//
// NOTE(review): per the os/signal docs, once Notify is called for these
// signals they no longer trigger Go's default action (process exit). A
// program that imports this package but never calls OnShutdown will
// therefore ignore Ctrl-C / SIGHUP — confirm this is intended.
func init() {
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGHUP)
}

// HealthCheck wraps the health-check handler h. A request carrying the header
// "X-Meh-Restart: check" (sent by the primary/backup proxy) receives a bare
// 503 once a shutdown signal has been received, telling the proxy to route
// traffic to the other server. Every other request, and every request made
// before shutdown begins, is passed straight through to h; only meh-restart
// probes are affected by the restart logic.
func HealthCheck(h http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		isRestartProbe := req.Header.Get("X-Meh-Restart") == "check"
		if isRestartProbe && atomic.LoadInt32(&shuttingDown) != 0 {
			rw.WriteHeader(http.StatusServiceUnavailable)
			return
		}
		h.ServeHTTP(rw, req)
	})
}

// OnShutdown registers fn as the shutdown callback and starts the background
// goroutine that waits for a shutdown signal. When the signal arrives, restart
// probes handled by HealthCheck start returning 503, the goroutine sleeps for
// WaitForHealthCheck, and then fn is called.
//
// OnShutdown panics if fn is nil or if a callback has already been
// registered. It performs no locking, so it must not be called concurrently.
func OnShutdown(fn func()) {
	if fn == nil {
		// Fail fast: a nil fn would otherwise be accepted here and only
		// blow up inside the background goroutine once the signal arrives.
		// It would also leave shutdownFn nil, defeating the "once" guard below.
		panic("restart: OnShutdown called with nil function")
	}
	if shutdownFn != nil {
		panic("restart: OnShutdown should only be called once")
	}
	shutdownFn = fn
	go listenForShutdown()
}

// listenForShutdown blocks until the first signal arrives on sigs and then
// runs the shutdown sequence once:
//
//  1. signal.Stop undoes the Notify registration for sigs. Unless other code
//     is also subscribed, a repeated signal during the wait below therefore
//     takes Go's default action, which for SIGINT/SIGHUP is to exit.
//  2. shuttingDown is set to 1, which makes HealthCheck answer restart
//     probes with 503.
//  3. It sleeps for WaitForHealthCheck so proxies can stop routing here.
//  4. It calls the callback registered via OnShutdown, then returns.
func listenForShutdown() {
	// A plain receive suffices; a select with a single case adds nothing.
	<-sigs
	signal.Stop(sigs)
	atomic.StoreInt32(&shuttingDown, 1)
	log.Println("restart: received shutdown signal. health checks should now 503")
	log.Printf("restart: waiting %v before shutting down so connections stop arriving", WaitForHealthCheck)
	time.Sleep(WaitForHealthCheck)
	shutdownFn()
}
