Skip to content

Commit a44d608

Browse files
Add support for OnDisconnect when WS connection is closed
v4 only
1 parent c0c98bf commit a44d608

File tree

4 files changed

+36
-6
lines changed

4 files changed

+36
-6
lines changed

Diff for: engineio/server.v4.go

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ func (v4 *serverV4) serveTransport(w http.ResponseWriter, r *http.Request) (tran
111111
ctx = v4.sessions.WithCancel(ctx)
112112
ctx = v4.sessions.WithInterval(ctx, v4.pingInterval)
113113
ctx = v4.sessions.WithTimeout(ctx, v4.pingTimeout)
114+
ctx = v4.sessions.WithDisconnectOnClose(ctx)
114115

115116
go func() {
116117
v4.transportRunError <- upgrade.transport.Run(w, r.WithContext(ctx), append(v4.eto, opts...)...)

Diff for: engineio/session/manage.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ const (
99
SessionIntervalKey sessionCtxKey = "interval"
1010
SessionExtendTimeoutKey sessionCtxKey = "timeout-extend"
1111

12-
SessionCloseChannelKey sessionCtxKey = "cancel-channel"
13-
SessionCloseFunctionKey sessionCtxKey = "cancel-function"
12+
SessionCloseChannelKey sessionCtxKey = "cancel-channel"
13+
SessionCloseFunctionKey sessionCtxKey = "cancel-function"
14+
SessionDisconnectFunctionKey sessionCtxKey = "cancel-disconnect"
1415
)
1516

1617
type (

Diff for: engineio/sessions.go

+18
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type transportSessions interface {
2525
WithCancel(ctx context.Context) context.Context
2626
WithTimeout(ctx context.Context, d time.Duration) context.Context
2727
WithInterval(ctx context.Context, d time.Duration) context.Context
28+
WithDisconnectOnClose(ctx context.Context) context.Context
2829
}
2930

3031
type sessions struct {
@@ -88,6 +89,23 @@ type lifecycle struct {
8889
removeTransport func(SessionID)
8990
}
9091

92+
func (c *lifecycle) WithDisconnectOnClose(ctx context.Context) context.Context {
93+
sessionID, ok := ctx.Value(ctxSessionID).(SessionID)
94+
if !ok {
95+
// there is no session to attach the context to
96+
return ctx
97+
}
98+
ctx = context.WithValue(ctx, eios.SessionDisconnectFunctionKey, func() func() {
99+
return func() {
100+
c.removeSession(sessionID)
101+
if c.removeTransport != nil {
102+
c.removeTransport(sessionID)
103+
}
104+
}
105+
})
106+
return ctx
107+
}
108+
91109
func (c *lifecycle) WithCancel(ctx context.Context) context.Context {
92110
sessionID, ok := ctx.Value(ctxSessionID).(SessionID)
93111
if !ok {

Diff for: engineio/transport/transport.websocket.go

+14-4
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@ func (t *WebsocketTransport) With(opts ...Option) {
6666

6767
func (t *WebsocketTransport) InnerTransport() *Transport { return t.Transport }
6868

69+
func (t *WebsocketTransport) ConnClose(code ws.StatusCode, reason string, ctx context.Context) error {
70+
err := t.conn.Close(code, reason)
71+
if done, ok := ctx.Value(eios.SessionDisconnectFunctionKey).(func() func()); ok {
72+
if cleanup := done(); cleanup != nil {
73+
cleanup()
74+
}
75+
}
76+
return err
77+
}
78+
6979
func (t *WebsocketTransport) Run(w http.ResponseWriter, r *http.Request, opts ...Option) (err error) {
7080
t.With(opts...)
7181

@@ -95,7 +105,7 @@ func (t *WebsocketTransport) Run(w http.ResponseWriter, r *http.Request, opts ..
95105
grp.Go(func() error { return t.outgoing(r.WithContext(ctx)) })
96106

97107
err = grp.Wait()
98-
t.conn.Close(ws.StatusNormalClosure, "done")
108+
t.ConnClose(ws.StatusNormalClosure, "done", ctx)
99109
return err
100110
}
101111

@@ -161,7 +171,7 @@ func (t *WebsocketTransport) incoming(ctx context.Context) (err error) {
161171

162172
var done func()
163173
var reason string
164-
defer func() { t.conn.Close(ws.StatusNormalClosure, reason) }()
174+
defer func() { t.ConnClose(ws.StatusNormalClosure, reason, ctx) }()
165175

166176
var start = time.Now()
167177
Write:
@@ -259,7 +269,7 @@ func (t *WebsocketTransport) outgoing(r *http.Request) (err error) {
259269
}
260270

261271
var unbuffered = new(sync.WaitGroup)
262-
defer t.conn.Close(ws.StatusNormalClosure, "read")
272+
defer t.ConnClose(ws.StatusNormalClosure, "read", ctx)
263273

264274
for {
265275
if !t.buffered {
@@ -309,7 +319,7 @@ func (t *WebsocketTransport) outgoing(r *http.Request) (err error) {
309319
}
310320
}
311321
t.conn.CloseRead(ctx)
312-
t.conn.Close(ws.StatusNormalClosure, "cross origin WebSocket accepted")
322+
t.ConnClose(ws.StatusNormalClosure, "cross origin WebSocket accepted", ctx)
313323
return nil
314324
case eiop.PingPacket:
315325
cw, err := t.conn.Writer(ctx, ws.MessageText)

0 commit comments

Comments
 (0)