package http

import (
	"context"
	"net"
	"net/http"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"

	"github.com/influxdata/platform/logger"
	"go.uber.org/zap"
)

// DefaultShutdownTimeout is the ShutdownTimeout that NewServer assigns to a
// new Server: the longest a graceful shutdown is allowed to run.
const DefaultShutdownTimeout time.Duration = 20 * time.Second

// Server is an abstraction around the http.Server that handles a server process.
// It manages the full lifecycle of a server by serving a handler on a socket.
// If signals have been registered, it will attempt to terminate the server using
// Shutdown if a signal is received and will force a shutdown if a second signal
// is received.
//
// Create a Server with NewServer: the zero value has no underlying http.Server
// and is not usable.
type Server struct {
	// ShutdownTimeout bounds how long a graceful shutdown may run after the
	// first signal is received. NewServer sets it to DefaultShutdownTimeout.
	ShutdownTimeout time.Duration

	// srv is the wrapped server; built by NewServer around the handler.
	srv *http.Server
	// signals is the set of signals registered through ListenForSignals.
	signals map[os.Signal]struct{}
	// logger receives lifecycle messages; NewServer substitutes a no-op
	// logger when none is given.
	logger *zap.Logger
	// wg tracks the goroutines started by serve and shutdown so Serve can
	// wait for them before returning.
	wg sync.WaitGroup
}

// NewServer wraps handler in a Server that is ready to Serve. The graceful
// shutdown timeout starts out as DefaultShutdownTimeout. A nil logger is
// replaced with a no-op logger so the Server can always log safely.
func NewServer(handler http.Handler, logger *zap.Logger) *Server {
	s := &Server{
		ShutdownTimeout: DefaultShutdownTimeout,
		srv:             &http.Server{Handler: handler},
		logger:          logger,
	}
	if s.logger == nil {
		s.logger = zap.NewNop()
	}
	return s
}

// Serve accepts connections on listener and serves the handler until either
// the underlying http.Server stops on its own or a registered signal arrives.
// On a signal, it runs a graceful shutdown and returns that result. Otherwise,
// it returns the error reported by http.Server.Serve. Before returning, it
// stops signal delivery and waits for every goroutine it started to exit.
func (s *Server) Serve(listener net.Listener) error {
	// Deferred first so that it runs last, after signal delivery is stopped.
	defer s.wg.Wait()

	sigs, stopNotify := s.notifyOnSignals()
	defer stopNotify()

	serveErrs := s.serve(listener)

	var err error
	select {
	case err = <-serveErrs:
		// The server exited by itself; report why.
	case <-sigs:
		// A registered signal arrived (sigs is nil, and so never ready, when
		// no signals were registered). Begin the shutdown sequence; a
		// further signal on the same channel escalates it.
		err = s.shutdown(sigs)
	}
	return err
}

// serve runs http.Server.Serve on listener in a background goroutine tracked
// by s.wg. The returned channel receives the non-nil error from Serve, if any,
// and is then closed; its buffer of one lets the goroutine finish even if
// nobody reads the result.
func (s *Server) serve(listener net.Listener) <-chan error {
	result := make(chan error, 1)
	s.wg.Add(1)
	go func() {
		// Deferred in this order so the channel is closed before Done runs.
		defer s.wg.Done()
		defer close(result)

		if err := s.srv.Serve(listener); err != nil {
			result <- err
		}
	}()
	return result
}

// shutdown gracefully stops the server, giving in-flight requests at most
// s.ShutdownTimeout to complete. If another signal arrives on signalCh while
// the graceful shutdown is still running, the wait is cut short.
//
// If the graceful shutdown does not complete, the remaining connections are
// forcibly closed with http.Server.Close. This happens when the timeout
// expires, when a second signal cancels it, or when closing the listeners
// fails. http.Server.Shutdown does not close active connections itself.
//
// It returns the error from http.Server.Shutdown. That is the context's error
// when the timeout expires or a second signal is received.
func (s *Server) shutdown(signalCh <-chan os.Signal) error {
	s.logger.Info("Shutting down server", logger.DurationLiteral("timeout", s.ShutdownTimeout))

	// The graceful shutdown must finish within the configured ShutdownTimeout.
	ctx, cancel := context.WithTimeout(context.Background(), s.ShutdownTimeout)
	defer cancel()

	// done releases the watcher goroutine below once Shutdown has returned.
	done := make(chan struct{})
	defer close(done)

	// Wait for another signal to cancel the graceful shutdown early.
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		select {
		case <-signalCh:
			s.logger.Info("Initializing hard shutdown")
			cancel()
		case <-done:
		}
	}()

	if err := s.srv.Shutdown(ctx); err != nil {
		// Shutdown gave up without closing the active connections; close
		// them now so the server really stops. Close's own error is ignored
		// on purpose: the Shutdown error is the one callers need to see.
		_ = s.srv.Close()
		return err
	}
	return nil
}

// ListenForSignals registers signals that should shut the server down. Calls
// accumulate: each one adds to the signals registered so far. The signals are
// not captured until Serve is called.
func (s *Server) ListenForSignals(signals ...os.Signal) {
	set := s.signals
	if set == nil {
		set = make(map[os.Signal]struct{})
		s.signals = set
	}

	for i := range signals {
		set[signals[i]] = struct{}{}
	}
}

// notifyOnSignals starts delivery of every registered signal to a fresh
// channel. It returns that channel and a function that stops delivery. With
// no registered signals, it returns a nil channel, which never becomes ready
// in a select, and a no-op stop function.
func (s *Server) notifyOnSignals() (_ <-chan os.Signal, cancel func()) {
	count := len(s.signals)
	if count == 0 {
		return nil, func() {}
	}

	// Flatten the registered set into the variadic form signal.Notify wants.
	wanted := make([]os.Signal, 0, count)
	for sig := range s.signals {
		wanted = append(wanted, sig)
	}

	// Buffer two deliveries per signal type. The first starts a graceful
	// shutdown; the second escalates it.
	ch := make(chan os.Signal, 2*len(wanted))
	signal.Notify(ch, wanted...)

	stop := func() { signal.Stop(ch) }
	return ch, stop
}

// ListenAndServe opens a TCP listener on addr and serves handler on it until
// the server fails or is shut down. It registers os.Interrupt and SIGTERM as
// shutdown signals, the usual choice for a long-running service. A nil logger
// is allowed (see NewServer).
func ListenAndServe(addr string, handler http.Handler, logger *zap.Logger) error {
	s := NewServer(handler, logger)
	s.ListenForSignals(os.Interrupt, syscall.SIGTERM)

	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	return s.Serve(ln)
}