Skip to content

Commit 8bde34c

Browse files
author
Pushkar N Kulkarni
authored
permessage-deflate: implement context takeover (#26)
* permessage-deflate: implement context takeover * Use negative windowBits values for the deflater The compression format of the output produced by zlibs deflater is decided by the windowBits values it is configured with. We used positive windowBits values which, according to the zlib manual, emits the zlib stream starting with a zlib header. WebSockets use raw deflate streams, and not zlib streams. Consequently, we had to strip the zlib header. The deflater can be configured to emit a raw deflate format by using negative values for the windowBits. With this in place, there's no zlib header in the output stream. * Handle the special case of LZ77 window size of 256 The server_max_window_bits value sent by the client in the negotiation offer is used to configure the deflater on the server. This acceptance is also notified to the client by returning the server_max_window_bits header in the negotiation response. A special case arises when we receive server_max_window_bits = 8. There's an open zlib issue with the window size of 256 (windoBits=8). For zlib streams, the library silently changes the windowBits to 9 and informs the inflater via the zlib header. However, this apparent hack is not feasible for raw deflate streams, which are used in WebSockets. To take care of this fact, zlib has been patched to reject a window size of 256, via a deflateInit2() failure. Details here: madler/zlib#171 As a result, we need to handle a window size of 256 (windowBits=8) as a special case. We silently change it to 9 and inform the client via a suitable header in the negotiation response. * Close deflater and inflater streams on connection close Without context takeover, the deflater and inflater are initialized and closed on every message. However, with context takeover, the deflater and inflater persist for the entire lifetime of the connection. They need to be closed on a connection close. The deflater is on a ChannelOutboundHandler and must be closed when the WebSocketConnections issues a channel.close(). The inflater is on a ChannelInboundHandler, it must be closed when we receive a `closeConnection` frame from the remote peer.
1 parent 4f784fc commit 8bde34c

7 files changed

+193
-58
lines changed

Sources/KituraWebSocket/PermessageDeflate.swift

+55-19
Original file line numberDiff line numberDiff line change
@@ -26,35 +26,44 @@ class PermessageDeflate: WebSocketProtocolExtension {
2626
guard header.hasPrefix("permessage-deflate") else { return [] }
2727
var deflaterMaxWindowBits: Int32 = 15
2828
var inflaterMaxWindowBits: Int32 = 15
29-
//TODO: change these defaults to false after implementing context takeover
30-
var clientNoContextTakeover = true
31-
var serverNoContextTakeover = true
29+
var clientNoContextTakeover = false
30+
var serverNoContextTakeover = false
3231

3332
// Four parameters to handle:
3433
// * server_max_window_bits: the LZ77 sliding window size used by the server for compression
3534
// * client_max_window_bits: the LZ77 sliding window size used by the server for decompression
3635
// * server_no_context_takeover: prevent the server from using context-takeover
3736
// * client_no_context_takeover: prevent the client from using context-takeover
3837
for parameter in header.components(separatedBy: "; ") {
39-
// If we receieved a valid value for server_max_window_bits, configure the deflater to use it
38+
// If we receieved a valid value for server_max_window_bits, use it to configure the deflater
4039
if parameter.hasPrefix("server_max_window_bits") {
4140
let maxWindowBits = parameter.components(separatedBy: "=")
4241
guard maxWindowBits.count == 2 else { continue }
43-
if let mwBits = Int32(maxWindowBits[1]) {
44-
if mwBits >= 8 && mwBits <= 15 {
45-
deflaterMaxWindowBits = mwBits
46-
}
42+
guard let mwBits = Int32(maxWindowBits[1]) else { continue }
43+
if mwBits >= 8 && mwBits <= 15 {
44+
// We received a valid value. However there's a special case here:
45+
//
46+
// There's an open zlib issue which does not set the window size
47+
// to 256 (windowBits=8). For windowBits=8, zlib silently changes the
48+
// value to 9. However, this apparent hack works only with zlib streams.
49+
// WebSockets use raw deflate streams. For raw deflate streams, zlib has been
50+
// patched to ignore the windowBits value 8.
51+
// More details here: https://github.com/madler/zlib/issues/171
52+
//
53+
// So, if the server requested for server_max_window_bits=8, we are
54+
// going to use server_max_window_bits=9 instead and notify this in
55+
// our negotiation response too.
56+
deflaterMaxWindowBits = mwBits == 8 ? 9 : mwBits
4757
}
4858
}
4959

50-
// If we receieved a valid value for server_max_window_bits, configure the inflater to use it
60+
// If we received a valid client_max_window_bits value, use it to configure the inflater
5161
if parameter.hasPrefix("client_max_window_bits") {
5262
let maxWindowBits = parameter.components(separatedBy: "=")
5363
guard maxWindowBits.count == 2 else { continue }
54-
if let mwBits = Int32(maxWindowBits[1]) {
55-
if mwBits >= 8 && mwBits <= 15 {
56-
inflaterMaxWindowBits = mwBits
57-
}
64+
guard let mwBits = Int32(maxWindowBits[1]) else { continue }
65+
if mwBits >= 8 && mwBits <= 15 {
66+
inflaterMaxWindowBits = mwBits
5867
}
5968
}
6069

@@ -66,7 +75,6 @@ class PermessageDeflate: WebSocketProtocolExtension {
6675
serverNoContextTakeover = true
6776
}
6877
}
69-
7078
return [PermessageDeflateCompressor(maxWindowBits: deflaterMaxWindowBits, noContextTakeOver: serverNoContextTakeover),
7179
PermessageDeflateDecompressor(maxWindowBits: inflaterMaxWindowBits, noContextTakeOver: clientNoContextTakeover)]
7280
}
@@ -81,16 +89,44 @@ class PermessageDeflate: WebSocketProtocolExtension {
8189

8290
for parameter in header.components(separatedBy: "; ") {
8391
if parameter == "client_no_context_takeover" {
84-
//TODO: include client_no_context_takeover in the response
92+
response.append("; client_no_context_takeover")
8593
}
8694

8795
if parameter == "server_no_context_takeover" {
88-
//TODO: include server_no_context_takeover in the response
96+
response.append("; server_no_context_takeover")
97+
}
98+
99+
// If we receive a valid value for server_max_window_bits, we accept it and return if
100+
// in the response. If we receive an invalid value, we default to 15 and return the
101+
// same in the response. If we receive no value, we ignore this header.
102+
if parameter.hasPrefix("server_max_window_bits") {
103+
let maxWindowBits = parameter.components(separatedBy: "=")
104+
guard maxWindowBits.count == 2 else { continue }
105+
guard let mwBits = Int32(maxWindowBits[1]) else { continue }
106+
if mwBits >= 8 && mwBits <= 15 {
107+
// We received a valid value. However there's a special case here:
108+
//
109+
// There's an open zlib issue which does not set the window size
110+
// to 256 (windowBits=8). For windowBits=8, zlib silently changes the
111+
// value to 9. However, this apparent hack works only with zlib streams.
112+
// WebSockets use raw deflate streams. For raw deflate streams, zlib has been
113+
// patched to ignore the windowBits value 8.
114+
// More details here: https://github.com/madler/zlib/issues/171
115+
//
116+
// So, if the server requested for server_max_window_bits=8, we are
117+
// going to use server_max_window_bits=9 instead and notify this in
118+
// our negotiation response too.
119+
if mwBits == 8 {
120+
response.append("; server_max_window_bits=9")
121+
} else {
122+
response.append("; \(parameter)")
123+
}
124+
} else {
125+
// we received an invalid value
126+
response.append("; server_max_window_bits=15")
127+
}
89128
}
90129
}
91-
//TODO: remove this after we have implemented context takeover
92-
response.append("; server_no_context_takeover")
93-
response.append("; client_no_context_takeover")
94130
return response
95131
}
96132
}

Sources/KituraWebSocket/PermessageDeflateCompressor.swift

+31-13
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ class PermessageDeflateCompressor : ChannelOutboundHandler {
4444
// The zlib stream
4545
private var stream: z_stream = z_stream()
4646

47+
// Initialize the z_stream only once if context takeover is enabled
48+
private var streamInitialized = false
49+
4750
// PermessageDeflateCompressor is an outbound handler, this function gets called when a frame is written to the channel by WebSocketConnection.
4851
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
4952
var frame = unwrapOutboundIn(data)
@@ -77,19 +80,38 @@ class PermessageDeflateCompressor : ChannelOutboundHandler {
7780
_ = context.writeAndFlush(self.wrapOutboundOut(deflatedFrame))
7881
}
7982

83+
func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
84+
// PermessageDeflateCompressor is an outbound handler. If the underlying
85+
// WebSocketConnection decides to close the connection, the close message
86+
// needs to be intercepted and the deflater closed while we're using context takeover.
87+
if noContextTakeOver == false {
88+
deflateEnd(&stream)
89+
}
90+
context.close(mode: mode, promise: promise)
91+
}
92+
8093
func deflatePayload(in buffer: ByteBuffer, allocator: ByteBufferAllocator, dropFourTrailingOctets: Bool = false) -> ByteBuffer {
8194
// Initialize the deflater as per https://www.zlib.net/zlib_how.html
82-
stream.zalloc = nil
83-
stream.zfree = nil
84-
stream.opaque = nil
85-
86-
let rc = deflateInit2_(&stream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, self.maxWindowBits, 8,
87-
Z_DEFAULT_STRATEGY, ZLIB_VERSION, Int32(MemoryLayout<z_stream>.size))
88-
precondition(rc == Z_OK, "Unexpected return from zlib init: \(rc)")
95+
if noContextTakeOver || streamInitialized == false {
96+
stream.zalloc = nil
97+
stream.zfree = nil
98+
stream.opaque = nil
99+
stream.next_in = nil
100+
stream.avail_in = 0
101+
// The zlib manual asks us to provide a negative windowBits value for raw deflate
102+
let rc = deflateInit2_(&stream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, -self.maxWindowBits, 8,
103+
Z_DEFAULT_STRATEGY, ZLIB_VERSION, Int32(MemoryLayout<z_stream>.size))
104+
precondition(rc == Z_OK, "Unexpected return from zlib init: \(rc)")
105+
self.streamInitialized = true
106+
}
89107

90108
defer {
91-
// Deinitialize the deflater before returning
92-
deflateEnd(&stream)
109+
if noContextTakeOver {
110+
// We aren't doing a context takeover.
111+
// This means the deflater is to be used on a per-message basis.
112+
// So, we deinitialize the deflater before returning.
113+
deflateEnd(&stream)
114+
}
93115
}
94116

95117
// Deflate/compress the payload
@@ -114,10 +136,6 @@ class PermessageDeflateCompressor : ChannelOutboundHandler {
114136
precondition(inputBuffer.readableBytes == 0)
115137
precondition(outputBuffer.readableBytes > 0)
116138

117-
// Remove the 0x78 0x9C zlib header added by zlib
118-
_ = outputBuffer.readBytes(length: 2)
119-
outputBuffer.discardReadBytes()
120-
121139
// Ignore the 0, 0, 0xff, 0xff trailer added by zlib
122140
if dropFourTrailingOctets {
123141
outputBuffer = outputBuffer.getSlice(at: 0, length: outputBuffer.readableBytes-4) ?? outputBuffer

Sources/KituraWebSocket/PermessageDeflateDecompressor.swift

+22-9
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class PermessageDeflateDecompressor : ChannelInboundHandler {
3636
// The zlib stream
3737
private var stream: z_stream = z_stream()
3838

39+
private var streamInitialized = false
40+
3941
// A buffer to accumulate payload across multiple frames
4042
var payload: ByteBuffer?
4143

@@ -56,6 +58,12 @@ class PermessageDeflateDecompressor : ChannelInboundHandler {
5658
let frame = unwrapInboundIn(data)
5759
// We should either have a data frame with rsv1 set, or a continuation frame of a compressed message. There's nothing to do otherwise.
5860
guard frame.isCompressedDataFrame || (frame.isContinuationFrame && self.receivingCompressedMessage) else {
61+
// If we are using context takeover, this is a good time to free the zstream!
62+
if streamInitialized && frame.opcode == .connectionClose && !noContextTakeOver {
63+
deflateEnd(&stream)
64+
streamInitialized = false
65+
}
66+
5967
context.fireChannelRead(self.wrapInboundOut(frame))
6068
return
6169
}
@@ -90,17 +98,22 @@ class PermessageDeflateDecompressor : ChannelInboundHandler {
9098

9199
func inflatePayload(in buffer: ByteBuffer, allocator: ByteBufferAllocator) -> ByteBuffer {
92100
// Initialize the inflater as per https://www.zlib.net/zlib_how.html
93-
stream.zalloc = nil
94-
stream.zfree = nil
95-
stream.opaque = nil
96-
stream.avail_in = 0
97-
stream.next_in = nil
98-
let rc = inflateInit2_(&stream, -self.maxWindowBits, ZLIB_VERSION, Int32(MemoryLayout<z_stream>.size))
99-
precondition(rc == Z_OK, "Unexpected return from zlib init: \(rc)")
101+
if noContextTakeOver || streamInitialized == false {
102+
stream.zalloc = nil
103+
stream.zfree = nil
104+
stream.opaque = nil
105+
stream.avail_in = 0
106+
stream.next_in = nil
107+
let rc = inflateInit2_(&stream, -self.maxWindowBits, ZLIB_VERSION, Int32(MemoryLayout<z_stream>.size))
108+
precondition(rc == Z_OK, "Unexpected return from zlib init: \(rc)")
109+
self.streamInitialized = true
110+
}
100111

101112
defer {
102-
// Deinitialize before returning
103-
inflateEnd(&stream)
113+
if noContextTakeOver {
114+
// Deinitialize before returning
115+
inflateEnd(&stream)
116+
}
104117
}
105118

106119
// Inflate/decompress the payload

Sources/KituraWebSocket/WebSocketConnection.swift

+4-4
Original file line numberDiff line numberDiff line change
@@ -335,10 +335,10 @@ extension WebSocketConnection {
335335
let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data)
336336
let promise = context.eventLoop.makePromise(of: Void.self)
337337
context.writeAndFlush(self.wrapOutboundOut(frame), promise: promise)
338-
promise.futureResult.whenComplete { _ in
339-
if hard {
340-
_ = context.close(mode: .output)
341-
}
338+
if hard {
339+
promise.futureResult.flatMap { _ in
340+
context.close(mode: .output)
341+
}.whenComplete { _ in }
342342
}
343343
awaitClose = true
344344
}

Tests/KituraWebSocketTests/ComplexTests.swift

+35-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ class ComplexTests: KituraTest {
2929
("testPingBetweenBinaryFrames", testPingBetweenBinaryFrames),
3030
("testPingBetweenTextFrames", testPingBetweenTextFrames),
3131
("testTextShortAndMediumFrames", testTextShortAndMediumFrames),
32-
("testTextTwoShortFrames", testTextTwoShortFrames)
32+
("testTextTwoShortFrames", testTextTwoShortFrames),
33+
("testTwoMessagesWithContextTakeover", testTwoMessagesWithContextTakeover),
34+
("testTwoMessagesWithClientContextTakeover", testTwoMessagesWithClientContextTakeover),
35+
("testTwoMessagesWithServerContextTakeover", testTwoMessagesWithServerContextTakeover),
36+
("testTwoMessagesWithNoContextTakeover", testTwoMessagesWithNoContextTakeover),
3337
]
3438
}
3539

@@ -58,7 +62,7 @@ class ComplexTests: KituraTest {
5862
self.performTest(framesToSend: [(false, self.opcodeBinary, shortBinaryPayload), (true, self.opcodeContinuation, mediumBinaryPayload)],
5963
expectedFrames: [(true, self.opcodeBinary, expectedBinaryPayload)],
6064
expectation: expectation, negotiateCompression: true, compressed: true)
61-
}, { expectation in
65+
}, { expectation in
6266
self.performTest(framesToSend: [(false, self.opcodeBinary, shortBinaryPayload), (true, self.opcodeContinuation, mediumBinaryPayload)],
6367
expectedFrames: [(true, self.opcodeBinary, expectedBinaryPayload)],
6468
expectation: expectation, negotiateCompression: true, compressed: false)
@@ -183,4 +187,33 @@ class ComplexTests: KituraTest {
183187
expectation: expectation, negotiateCompression: true, compressed: false)
184188
})
185189
}
190+
191+
func testTwoMessages(contextTakeover: ContextTakeover = .both) {
192+
register(closeReason: .noReasonCodeSent)
193+
194+
let text = "RFC7692 specifies a framework for adding compression functionality to the WebSocket Protocol"
195+
let textPayload = self.payload(text: text)
196+
197+
performServerTest(asyncTasks: { expectation in
198+
self.performTest(framesToSend: [(true, self.opcodeText, textPayload), (true, self.opcodeText, textPayload)],
199+
expectedFrames: [(true, self.opcodeText, textPayload), (true, self.opcodeText, textPayload)],
200+
expectation: expectation, negotiateCompression: true, compressed: true, contextTakeover: contextTakeover)
201+
})
202+
}
203+
204+
func testTwoMessagesWithContextTakeover() {
205+
testTwoMessages(contextTakeover: .both)
206+
}
207+
208+
func testTwoMessagesWithClientContextTakeover() {
209+
testTwoMessages(contextTakeover: .client)
210+
}
211+
212+
func testTwoMessagesWithServerContextTakeover() {
213+
testTwoMessages(contextTakeover: .server)
214+
}
215+
216+
func testTwoMessagesWithNoContextTakeover() {
217+
testTwoMessages(contextTakeover: .none)
218+
}
186219
}

Tests/KituraWebSocketTests/KituraTest+Frames.swift

+7-4
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ extension KituraTest {
6565
}
6666

6767
//Sometimes, we may have a non-final frame as the last frame
68-
func sendFrame(final: Bool, withOpcode: Int, withMasking: Bool=true, withPayload: NSData, on channel: Channel, lastFrame: Bool = false, compressed: Bool = false) {
68+
func sendFrame(final: Bool, withOpcode: Int, withMasking: Bool=true, withPayload: NSData, on channel: Channel, lastFrame: Bool = false, compressed: Bool = false, contextTakeover: ContextTakeover? = nil) {
6969
var buffer = channel.allocator.buffer(capacity: 8)
7070
var payloadLength = withPayload.length
7171

@@ -76,7 +76,7 @@ extension KituraTest {
7676
}
7777

7878
if compressed {
79-
payloadBuffer = PermessageDeflateCompressor().deflatePayload(in: payloadBuffer, allocator: ByteBufferAllocator(), dropFourTrailingOctets: final)
79+
payloadBuffer = self.compressor.deflatePayload(in: payloadBuffer, allocator: ByteBufferAllocator(), dropFourTrailingOctets: final)
8080
payloadLength = payloadBuffer.readableBytes
8181
}
8282

@@ -192,11 +192,14 @@ class WebSocketClientHandler: ChannelInboundHandler {
192192

193193
var compressed: Bool = false
194194

195-
init(expectedFrames: [(Bool, Int, NSData)], expectation: XCTestExpectation, compressed: Bool = false) {
195+
var decompressor: PermessageDeflateDecompressor
196+
197+
init(expectedFrames: [(Bool, Int, NSData)], expectation: XCTestExpectation, compressed: Bool = false, decompressor: PermessageDeflateDecompressor) {
196198
self.numberOfFramesExpected = expectedFrames.count
197199
self.expectedFrames = expectedFrames
198200
self.expectation = expectation
199201
self.compressed = compressed
202+
self.decompressor = decompressor
200203
}
201204

202205
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
@@ -225,7 +228,7 @@ class WebSocketClientHandler: ChannelInboundHandler {
225228
currentFramePayload += [0, 0, 0xff, 0xff]
226229
var payloadBuffer = ByteBufferAllocator().buffer(capacity: 8)
227230
payloadBuffer.writeBytes(currentFramePayload)
228-
let inflatedBuffer = PermessageDeflateDecompressor().inflatePayload(in: payloadBuffer, allocator: ByteBufferAllocator())
231+
let inflatedBuffer = self.decompressor.inflatePayload(in: payloadBuffer, allocator: ByteBufferAllocator())
229232
currentFramePayload = inflatedBuffer.getBytes(at: 0, length: inflatedBuffer.readableBytes) ?? []
230233
}
231234
let currentFramePayloadPtr = UnsafeBufferPointer(start: &currentFramePayload, count: currentFramePayload.count)

0 commit comments

Comments
 (0)