mirror of https://github.com/milvus-io/milvus.git
201 lines
4.2 KiB
Go
201 lines
4.2 KiB
Go
package httpserver
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func defaultResponse(c *gin.Context) {
|
|
c.String(http.StatusRequestTimeout, "timeout")
|
|
}
|
|
|
|
// BufferPool is a pool of reusable *bytes.Buffer values backed by sync.Pool.
// The zero value is ready to use.
type BufferPool struct {
	pool sync.Pool
}

// Get returns a buffer taken from the pool, or a freshly allocated one when
// the pool has nothing usable to offer. The returned buffer is never nil.
func (p *BufferPool) Get() *bytes.Buffer {
	// The comma-ok assertion yields ok == false for an empty pool (nil
	// interface); the extra nil check rejects a typed nil pointer so callers
	// can always use the result without checking it.
	if buf, ok := p.pool.Get().(*bytes.Buffer); ok && buf != nil {
		return buf
	}
	return &bytes.Buffer{}
}

// Put resets buf and hands it back to the pool for reuse.
//
// A nil buf is ignored: sync.Pool only filters out nil interfaces, so a typed
// nil pointer would otherwise be stored and later handed out by Get.
// Resetting here means stale bytes from a previous user are never kept in the
// pool, even if a caller forgets to reset the buffer itself.
func (p *BufferPool) Put(buf *bytes.Buffer) {
	if buf == nil {
		return
	}
	buf.Reset()
	p.pool.Put(buf)
}
|
|
|
|
// Timeout struct
|
|
type Timeout struct {
|
|
timeout time.Duration
|
|
handler gin.HandlerFunc
|
|
response gin.HandlerFunc
|
|
}
|
|
|
|
// Writer is a writer with memory buffer
|
|
type Writer struct {
|
|
gin.ResponseWriter
|
|
body *bytes.Buffer
|
|
headers http.Header
|
|
mu sync.Mutex
|
|
timeout bool
|
|
wroteHeaders bool
|
|
code int
|
|
}
|
|
|
|
// NewWriter will return a timeout.Writer pointer
|
|
func NewWriter(w gin.ResponseWriter, buf *bytes.Buffer) *Writer {
|
|
return &Writer{ResponseWriter: w, body: buf, headers: make(http.Header)}
|
|
}
|
|
|
|
// Write will write data to response body
|
|
func (w *Writer) Write(data []byte) (int, error) {
|
|
if w.timeout || w.body == nil {
|
|
return 0, nil
|
|
}
|
|
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
|
|
return w.body.Write(data)
|
|
}
|
|
|
|
// WriteHeader sends an HTTP response header with the provided status code.
|
|
// If the response writer has already written headers or if a timeout has occurred,
|
|
// this method does nothing.
|
|
func (w *Writer) WriteHeader(code int) {
|
|
if w.timeout || w.wroteHeaders {
|
|
return
|
|
}
|
|
|
|
// gin is using -1 to skip writing the status code
|
|
// see https://github.com/gin-gonic/gin/blob/a0acf1df2814fcd828cb2d7128f2f4e2136d3fac/response_writer.go#L61
|
|
if code == -1 {
|
|
return
|
|
}
|
|
|
|
checkWriteHeaderCode(code)
|
|
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
|
|
w.writeHeader(code)
|
|
w.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
func (w *Writer) writeHeader(code int) {
|
|
w.wroteHeaders = true
|
|
w.code = code
|
|
}
|
|
|
|
// Header will get response headers
|
|
func (w *Writer) Header() http.Header {
|
|
return w.headers
|
|
}
|
|
|
|
// WriteString will write string to response body
|
|
func (w *Writer) WriteString(s string) (int, error) {
|
|
return w.Write([]byte(s))
|
|
}
|
|
|
|
// FreeBuffer will release buffer pointer
|
|
func (w *Writer) FreeBuffer() {
|
|
// if not reset body,old bytes will put in bufPool
|
|
w.body.Reset()
|
|
w.body = nil
|
|
}
|
|
|
|
// Status we must override Status func here,
|
|
// or the http status code returned by gin.Context.Writer.Status()
|
|
// will always be 200 in other custom gin middlewares.
|
|
func (w *Writer) Status() int {
|
|
if w.code == 0 || w.timeout {
|
|
return w.ResponseWriter.Status()
|
|
}
|
|
return w.code
|
|
}
|
|
|
|
// checkWriteHeaderCode panics with the message
// "invalid http status code: <code>" unless code lies in the inclusive range
// [100, 999].
func checkWriteHeaderCode(code int) {
	if code >= 100 && code <= 999 {
		return
	}
	panic(fmt.Sprintf("invalid http status code: %d", code))
}
|
|
|
|
func timeoutMiddleware(handler gin.HandlerFunc) gin.HandlerFunc {
|
|
t := &Timeout{
|
|
timeout: HTTPDefaultTimeout,
|
|
handler: handler,
|
|
response: defaultResponse,
|
|
}
|
|
bufPool := &BufferPool{}
|
|
return func(c *gin.Context) {
|
|
timeoutSecond, err := strconv.ParseInt(c.Request.Header.Get(HTTPHeaderRequestTimeout), 10, 64)
|
|
if err == nil {
|
|
t.timeout = time.Duration(timeoutSecond) * time.Second
|
|
}
|
|
finish := make(chan struct{}, 1)
|
|
panicChan := make(chan interface{}, 1)
|
|
|
|
w := c.Writer
|
|
buffer := bufPool.Get()
|
|
tw := NewWriter(w, buffer)
|
|
c.Writer = tw
|
|
buffer.Reset()
|
|
|
|
go func() {
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
panicChan <- p
|
|
}
|
|
}()
|
|
t.handler(c)
|
|
finish <- struct{}{}
|
|
}()
|
|
|
|
select {
|
|
case p := <-panicChan:
|
|
tw.FreeBuffer()
|
|
c.Writer = w
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{HTTPReturnCode: http.StatusInternalServerError})
|
|
panic(p)
|
|
|
|
case <-finish:
|
|
c.Next()
|
|
tw.mu.Lock()
|
|
defer tw.mu.Unlock()
|
|
dst := tw.ResponseWriter.Header()
|
|
for k, vv := range tw.Header() {
|
|
dst[k] = vv
|
|
}
|
|
|
|
if _, err := tw.ResponseWriter.Write(buffer.Bytes()); err != nil {
|
|
panic(err)
|
|
}
|
|
tw.FreeBuffer()
|
|
bufPool.Put(buffer)
|
|
|
|
case <-time.After(t.timeout):
|
|
c.Abort()
|
|
tw.mu.Lock()
|
|
defer tw.mu.Unlock()
|
|
tw.timeout = true
|
|
tw.FreeBuffer()
|
|
bufPool.Put(buffer)
|
|
|
|
c.Writer = w
|
|
t.response(c)
|
|
c.Writer = tw
|
|
}
|
|
}
|
|
}
|