2016-11-16 30 views
-1

我可以在下面的示例代码中看到两个主要问题,但我不知道如何正确解决它们。超时处理程序中的竞争条件

如果超时处理程序未通过errCh获取下一个处理程序已完成或发生错误的信号,它将回复请求的“408请求超时”。

这里的问题是ResponseWriter不安全,可能被多个goroutine使用。超时处理程序在执行下一个处理程序时启动一个新的goroutine。

问题:

  1. 如何防止下一个处理从编写到ResponseWriter当CTX的完成通道超时在超时处理程序。

  2. 如何防止超时处理程序在下一个处理程序正在写入ResponseWriter时回复408状态码,但尚未完成,并且ctx的Done通道在超时处理程序中超时。


package main 

import (
    "context" 
    "fmt" 
    "net/http" 
    "time" 
) 

func main() { 
    http.Handle("/race", handlerFunc(timeoutHandler)) 
    http.ListenAndServe(":8080", nil) 
} 

func timeoutHandler(w http.ResponseWriter, r *http.Request) error { 
    const seconds = 1 
    ctx, cancel := context.WithTimeout(r.Context(), time.Duration(seconds)*time.Second) 
    defer cancel() 

    r = r.WithContext(ctx) 

    errCh := make(chan error, 1) 
    go func() { 
    // w is not safe for concurrent use by multiple goroutines 
    errCh <- nextHandler(w, r) 
    }() 

    select { 
    case err := <-errCh: 
    return err 
    case <-ctx.Done(): 
    // w is not safe for concurrent use by multiple goroutines 
    http.Error(w, "Request timeout", 408) 
    return nil 
    } 
} 

func nextHandler(w http.ResponseWriter, r *http.Request) error { 
    // just for fun to simulate a better race condition 
    const seconds = 1 
    time.Sleep(time.Duration(seconds) * time.Second) 
    fmt.Fprint(w, "nextHandler") 
    return nil 
} 

type handlerFunc func(w http.ResponseWriter, r *http.Request) error 

func (fn handlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) { 
    if err := fn(w, r); err != nil { 
    http.Error(w, "Server error", 500) 
    } 
} 
+2

如何通过其他的东西比原来的'ResponseWriter'设置为'nextHandler()'?然后您必须将结果复制回'<-errCh'情况下的原始'ResponseWriter'。 –

+0

无论要写入ResponseWriter,还需要对超时负责。你正在为nextHandler提供一个超时上下文,所以应该能够处理超时本身。一般来说,如果只有一个处理程序负责编写响应,则会更容易。 – JimB

回答

0

这里是一个可能的解决方案,它基于@安迪的评论。

responseRecorder将被传递给nextHandler,而记录的响应将被复制回客户端:

func timeoutHandler(w http.ResponseWriter, r *http.Request) error { 
    const seconds = 1 
    ctx, cancel := context.WithTimeout(r.Context(), 
     time.Duration(seconds)*time.Second) 
    defer cancel() 

    r = r.WithContext(ctx) 

    errCh := make(chan error, 1) 
    w2 := newResponseRecorder() 
    go func() { 
     errCh <- nextHandler(w2, r) 
    }() 

    select { 
    case err := <-errCh: 
     if err != nil { 
      return err 
     } 

     w2.cloneHeader(w.Header()) 
     w.WriteHeader(w2.status) 
     w.Write(w2.buf.Bytes()) 
     return nil 
    case <-ctx.Done(): 
     http.Error(w, "Request timeout", 408) 
     return nil 
    } 
} 

这里是responseRecorder

type responseRecorder struct { 
    http.ResponseWriter 
    header http.Header 
    buf *bytes.Buffer 
    status int 
} 

func newResponseRecorder() *responseRecorder { 
    return &responseRecorder{ 
     header: http.Header{}, 
     buf: &bytes.Buffer{}, 
    } 
} 

func (w *responseRecorder) Header() http.Header { 
    return w.header 
} 

func (w *responseRecorder) cloneHeader(dst http.Header) { 
    for k, v := range w.header { 
     tmp := make([]string, len(v)) 
     copy(tmp, v) 
     dst[k] = tmp 
    } 
} 

func (w *responseRecorder) Write(data []byte) (int, error) { 
    if w.status == 0 { 
     w.WriteHeader(http.StatusOK) 
    } 
    return w.buf.Write(data) 
} 

func (w *responseRecorder) WriteHeader(status int) { 
    w.status = status 
}