Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add timeout callback #4142

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions rest/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
timeout time.Duration
unauthorizedCallback handler.UnauthorizedCallback
unsignedCallback handler.UnsignedCallback
timeoutCallback handler.TimeoutCallback
chain chain.Chain
middlewares []Middleware
shedder load.Shedder
Expand Down Expand Up @@ -64,7 +65,8 @@
}

func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
verifier func(chain.Chain) chain.Chain) chain.Chain {
verifier func(chain.Chain) chain.Chain,
) chain.Chain {
if fr.jwt.enabled {
if len(fr.jwt.prevSecret) == 0 {
chn = chn.Append(handler.Authorize(fr.jwt.secret,
Expand Down Expand Up @@ -95,7 +97,8 @@
}

func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
route Route, verifier func(chain.Chain) chain.Chain) error {
route Route, verifier func(chain.Chain) chain.Chain,
) error {
chn := ng.chain
if chn == nil {
chn = ng.buildChainWithNativeMiddlewares(fr, route, metrics)
Expand Down Expand Up @@ -124,7 +127,8 @@
}

func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route,
metrics *stat.Metrics) chain.Chain {
metrics *stat.Metrics,
) chain.Chain {
chn := chain.New()

if ng.conf.Middlewares.Trace {
Expand All @@ -148,7 +152,7 @@
chn = chn.Append(handler.SheddingHandler(ng.getShedder(fr.priority), metrics))
}
if ng.conf.Middlewares.Timeout {
chn = chn.Append(handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)))
chn = chn.Append(handler.TimeoutHandler(ng.checkedTimeout(fr.timeout), handler.WithTimeoutCallback(ng.timeoutCallback)))
}
if ng.conf.Middlewares.Recover {
chn = chn.Append(handler.RecoverHandler)
Expand Down Expand Up @@ -265,6 +269,10 @@
ng.unsignedCallback = callback
}

func (ng *engine) setTimeoutCallback(callback handler.TimeoutCallback) {
ng.timeoutCallback = callback

Check warning on line 273 in rest/engine.go

View check run for this annotation

Codecov / codecov/patch

rest/engine.go#L272-L273

Added lines #L272 - L273 were not covered by tests
}

func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
if !signature.enabled {
return func(chn chain.Chain) chain.Chain {
Expand Down
54 changes: 42 additions & 12 deletions rest/handler/timeouthandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,40 @@ const (
valueSSE = "text/event-stream"
)

type (
// TimeoutCallback defines the method of timeout callback.
TimeoutCallback func(w http.ResponseWriter, r *http.Request, err error)
// TimeoutOption defines the method to customize a TimeoutOptions.
TimeoutOptions struct {
Callback TimeoutCallback
}

// TimeoutOption defines the method to customize an TimeoutOptions.
TimeoutOption func(opts *TimeoutOptions)
)

var defaultTimeoutCallback TimeoutCallback = func(w http.ResponseWriter, r *http.Request, err error) {
if errors.Is(err, context.Canceled) {
w.WriteHeader(statusClientClosedRequest)
} else {
w.WriteHeader(http.StatusServiceUnavailable)
}
_, _ = io.WriteString(w, reason)
}

// TimeoutHandler returns the handler with given timeout.
// If client closed request, code 499 will be logged.
// Notice: even if canceled in server side, 499 will be logged as well.
func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
func TimeoutHandler(duration time.Duration, opt ...TimeoutOption) func(http.Handler) http.Handler {
var opts TimeoutOptions
for _, o := range opt {
o(&opts)
}

if opts.Callback == nil {
opts.Callback = defaultTimeoutCallback
}

return func(next http.Handler) http.Handler {
if duration <= 0 {
return next
Expand All @@ -40,6 +70,7 @@ func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
return &timeoutHandler{
handler: next,
dt: duration,
cb: opts.Callback,
}
}
}
Expand All @@ -51,10 +82,7 @@ func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
type timeoutHandler struct {
handler http.Handler
dt time.Duration
}

func (h *timeoutHandler) errorBody() string {
return reason
cb TimeoutCallback
}

func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -106,15 +134,10 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case <-ctx.Done():
tw.mu.Lock()
defer tw.mu.Unlock()
// there isn't any user-defined middleware before TimoutHandler,
// there isn't any user-defined middleware before TimeoutHandler,
// so we can guarantee that cancelation in biz related code won't come here.
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
if errors.Is(err, context.Canceled) {
w.WriteHeader(statusClientClosedRequest)
} else {
w.WriteHeader(http.StatusServiceUnavailable)
}
_, _ = io.WriteString(w, h.errorBody())
h.cb(w, r, err)
})
tw.timedOut = true
}
Expand Down Expand Up @@ -244,3 +267,10 @@ func relevantCaller() runtime.Frame {

return frame
}

// WithTimeoutCallback returns an AuthorizeOption with setting Timeout callback.
func WithTimeoutCallback(callback TimeoutCallback) TimeoutOption {
return func(opts *TimeoutOptions) {
opts.Callback = callback
}
}
16 changes: 15 additions & 1 deletion rest/handler/timeouthandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,20 @@ func TestTimeout(t *testing.T) {
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}

func TestTimeHandlerCallback(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Millisecond, WithTimeoutCallback(func(w http.ResponseWriter, r *http.Request, err error) {
w.WriteHeader(486)
}))
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Minute)
}))

req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, 486, resp.Code)
}

func TestWithinTimeout(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Second)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -112,7 +126,7 @@ func TestWithinTimeoutBadCode(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, resp.Code)
}

func TestWithTimeoutTimedout(t *testing.T) {
func TestWithTimeoutTimeout(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Millisecond)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond * 10)
Expand Down
10 changes: 9 additions & 1 deletion rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@
// WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*),
// fn lets caller customizing the response.
func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(http.ResponseWriter),
origin ...string) RunOption {
origin ...string,
) RunOption {
return func(server *Server) {
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...))
server.router = newCorsRouter(server.router, middlewareFn, origin...)
Expand Down Expand Up @@ -306,6 +307,13 @@
}
}

// WithTimeoutCallback returns a RunOption that with given timeout callback set.
func WithTimeoutCallback(callback handler.TimeoutCallback) RunOption {
return func(svr *Server) {
svr.ngin.setTimeoutCallback(callback)

Check warning on line 313 in rest/server.go

View check run for this annotation

Codecov / codecov/patch

rest/server.go#L311-L313

Added lines #L311 - L313 were not covered by tests
}
}

func handleError(err error) {
// ErrServerClosed means the server is closed manually
if err == nil || errors.Is(err, http.ErrServerClosed) {
Expand Down