milvus/internal/distributed/proxy/httpserver/timeout_middleware.go

201 lines
4.2 KiB
Go

package httpserver
import (
"bytes"
"fmt"
"net/http"
"strconv"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// defaultResponse is the reply installed by timeoutMiddleware for requests
// that exceed their deadline: a plain-text "timeout" body with HTTP status
// 408 (Request Timeout).
func defaultResponse(c *gin.Context) {
	const timeoutBody = "timeout"
	c.String(http.StatusRequestTimeout, timeoutBody)
}
// BufferPool is a small wrapper around sync.Pool that stores *bytes.Buffer
// values. The zero value is ready to use.
type BufferPool struct {
	pool sync.Pool
}

// Get takes a buffer out of the pool. When the pool has nothing to hand
// out, a freshly allocated empty buffer is returned instead.
func (p *BufferPool) Get() *bytes.Buffer {
	if item := p.pool.Get(); item != nil {
		return item.(*bytes.Buffer)
	}
	return new(bytes.Buffer)
}

// Put hands buf back to the pool so that a later Get may return it.
// The buffer is stored as-is; Put does not reset it.
func (p *BufferPool) Put(buf *bytes.Buffer) {
	p.pool.Put(buf)
}
// Timeout bundles the settings used by the timeout middleware.
type Timeout struct {
	// timeout is how long the wrapped handler may run before the request is
	// answered with the timeout response.
	timeout time.Duration

	// handler is the wrapped request handler; response writes the reply that
	// is sent to the client when handler does not finish within timeout.
	handler, response gin.HandlerFunc
}
// Writer is a gin.ResponseWriter that stages the response body and headers
// in memory instead of sending them to the client right away.
type Writer struct {
	gin.ResponseWriter

	// body collects everything passed to Write; it is nil after FreeBuffer.
	body *bytes.Buffer

	// headers is the staged header set returned by Header.
	headers http.Header

	// mu serializes body writes and status recording.
	mu sync.Mutex

	// timeout, once set, turns Write and WriteHeader into no-ops.
	// wroteHeaders records that a status code has already been accepted.
	timeout, wroteHeaders bool

	// code is the status recorded by WriteHeader; 0 means none yet.
	code int
}
// NewWriter builds a Writer that wraps w, stages body bytes in buf and
// starts out with an empty header set of its own.
func NewWriter(w gin.ResponseWriter, buf *bytes.Buffer) *Writer {
	tw := new(Writer)
	tw.ResponseWriter = w
	tw.body = buf
	tw.headers = http.Header{}
	return tw
}
// Write appends data to the in-memory response body.
//
// If the request has already timed out, or the buffer has been released with
// FreeBuffer, the data is dropped and (0, nil) is returned.
func (w *Writer) Write(data []byte) (int, error) {
	// The lock must be held before timeout/body are inspected: the timeout
	// branch of the middleware changes both while holding w.mu, so a check
	// made outside the lock could pass and then dereference a nil body once
	// the lock is finally acquired.
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.timeout || w.body == nil {
		return 0, nil
	}
	return w.body.Write(data)
}
// WriteHeader records the status code and forwards it to the wrapped
// ResponseWriter.
//
// The call is a no-op when a timeout has occurred, when a status code has
// already been accepted, or when code is -1. Any other code outside
// [100, 999] panics (see checkWriteHeaderCode).
func (w *Writer) WriteHeader(code int) {
	// Read the timeout/wroteHeaders flags under the lock: the middleware sets
	// timeout while holding w.mu, and checking outside the lock would let a
	// handler that raced with the timeout still push its status code to the
	// underlying writer after the timeout response was produced.
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.timeout || w.wroteHeaders {
		return
	}

	// gin is using -1 to skip writing the status code
	// see https://github.com/gin-gonic/gin/blob/a0acf1df2814fcd828cb2d7128f2f4e2136d3fac/response_writer.go#L61
	if code == -1 {
		return
	}

	// Panics on an invalid code; the deferred Unlock still releases w.mu.
	checkWriteHeaderCode(code)

	w.writeHeader(code)
	w.ResponseWriter.WriteHeader(code)
}
// writeHeader stores code as the response status and marks the headers as
// written. It does no locking of its own.
func (w *Writer) writeHeader(code int) {
	w.code, w.wroteHeaders = code, true
}
// Header returns the staged header map owned by this Writer, not the header
// map of the wrapped ResponseWriter.
func (w *Writer) Header() http.Header {
	staged := w.headers
	return staged
}
// WriteString converts s to bytes and passes it to Write, so it follows the
// same buffering and timeout rules as Write.
func (w *Writer) WriteString(s string) (int, error) {
	payload := []byte(s)
	return w.Write(payload)
}
// FreeBuffer empties the body buffer and drops the Writer's reference to it.
// It does not acquire w.mu itself.
func (w *Writer) FreeBuffer() {
	// Reset before letting go of the buffer; otherwise the old bytes would
	// still be inside it when it is handed back to the buffer pool.
	buf := w.body
	buf.Reset()
	w.body = nil
}
// Status reports the HTTP status code of the response.
//
// It has to be overridden here; otherwise gin.Context.Writer.Status() would
// always report 200 to other custom gin middlewares. The locally recorded
// code is used when there is one and no timeout has happened; in every other
// case the wrapped ResponseWriter is asked.
func (w *Writer) Status() int {
	if !w.timeout && w.code != 0 {
		return w.code
	}
	return w.ResponseWriter.Status()
}
// checkWriteHeaderCode panics unless code lies in the inclusive range
// [100, 999]; valid codes make it return silently.
func checkWriteHeaderCode(code int) {
	if code >= 100 && code <= 999 {
		return
	}
	panic(fmt.Sprintf("invalid http status code: %d", code))
}
// timeoutMiddleware wraps handler so that every request is answered within a
// deadline.
//
// The deadline defaults to HTTPDefaultTimeout and can be overridden per
// request with an integer number of seconds in the HTTPHeaderRequestTimeout
// header. The handler runs in its own goroutine and writes into a buffered
// Writer:
//   - if it finishes in time, the staged headers and body are copied to the
//     real response writer;
//   - if it panics, a 500 JSON response is written and the panic is re-raised
//     on the calling goroutine;
//   - if the deadline passes first, the request is aborted and the default
//     timeout response (408) is written instead.
func timeoutMiddleware(handler gin.HandlerFunc) gin.HandlerFunc {
	t := &Timeout{
		timeout:  HTTPDefaultTimeout,
		handler:  handler,
		response: defaultResponse,
	}
	bufPool := &BufferPool{}

	return func(c *gin.Context) {
		// Resolve the deadline into a request-local variable. t is shared by
		// every request served through this middleware, so storing the header
		// value in t.timeout would race between concurrent requests and would
		// leak one client's override into later requests that send no header.
		//
		// NOTE(review): a zero or negative header value yields a non-positive
		// duration, which makes the timer fire immediately — confirm whether
		// such values should be rejected instead.
		timeout := t.timeout
		if timeoutSecond, err := strconv.ParseInt(c.Request.Header.Get(HTTPHeaderRequestTimeout), 10, 64); err == nil {
			timeout = time.Duration(timeoutSecond) * time.Second
		}

		// Both channels have capacity 1 so the handler goroutine can always
		// deliver its result and exit, even when nobody is receiving anymore.
		finish := make(chan struct{}, 1)
		panicChan := make(chan interface{}, 1)

		// Swap the real writer for a buffering one while the handler runs.
		w := c.Writer
		buffer := bufPool.Get()
		tw := NewWriter(w, buffer)
		c.Writer = tw
		buffer.Reset()

		go func() {
			defer func() {
				if p := recover(); p != nil {
					panicChan <- p
				}
			}()
			t.handler(c)
			finish <- struct{}{}
		}()

		// An explicit timer (rather than time.After) lets us release it as
		// soon as the handler finishes or panics.
		timer := time.NewTimer(timeout)
		defer timer.Stop()

		select {
		case p := <-panicChan:
			tw.FreeBuffer()
			c.Writer = w
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{HTTPReturnCode: http.StatusInternalServerError})
			// Re-raise on the calling goroutine.
			panic(p)

		case <-finish:
			c.Next()
			tw.mu.Lock()
			defer tw.mu.Unlock()

			// Publish the staged headers, then the staged body.
			dst := tw.ResponseWriter.Header()
			for k, vv := range tw.Header() {
				dst[k] = vv
			}
			if _, err := tw.ResponseWriter.Write(buffer.Bytes()); err != nil {
				panic(err)
			}
			tw.FreeBuffer()
			bufPool.Put(buffer)

		case <-timer.C:
			c.Abort()
			tw.mu.Lock()
			defer tw.mu.Unlock()

			// From here on the buffered writer silently drops whatever the
			// still-running handler tries to write.
			tw.timeout = true
			tw.FreeBuffer()
			bufPool.Put(buffer)

			// Send the timeout reply through the real writer, then put the
			// buffered writer back on the context.
			c.Writer = w
			t.response(c)
			c.Writer = tw
		}
	}
}