1
1
package main
2
2
3
3
import (
4
+ "errors"
4
5
"flag"
5
6
"fmt"
6
7
"io"
@@ -199,6 +200,203 @@ func isNormalError(err error) bool {
199
200
return false
200
201
}
201
202
203
+ type fragmentedReader struct {
204
+ io.Reader
205
+ header [5 ]byte
206
+ buffer []byte
207
+ handshake []byte
208
+ handshakeErr error
209
+ processed bool
210
+ }
211
+
212
+ var errNotHandshakeRecord = errors .New ("not a handshake record" )
213
+
214
+ func (c * fragmentedReader ) readHandshakeRecord () ([]byte , error ) {
215
+ n , err := io .ReadFull (c .Reader , c .header [:5 ])
216
+ c .buffer = append (c .buffer , c .header [:n ]... )
217
+ if err != nil {
218
+ if err == io .ErrUnexpectedEOF {
219
+ err = io .EOF
220
+ }
221
+ return nil , err
222
+ }
223
+ // log.Printf("readHandshakeRecord: header %q", c.header[:5])
224
+
225
+ // TLS 1.0 handshake record
226
+ if c .header [0 ] == 0x16 && c .header [1 ] == 0x03 && c .header [2 ] == 0x01 {
227
+ n = int (c .header [3 ])<< 8 | int (c .header [4 ])
228
+ b := make ([]byte , n )
229
+ n , err = io .ReadFull (c .Reader , b )
230
+ c .buffer = append (c .buffer , b [:n ]... )
231
+ if err != nil {
232
+ if err == io .ErrUnexpectedEOF {
233
+ err = io .EOF
234
+ }
235
+ return nil , err
236
+ }
237
+
238
+ // log.Printf("readHandshakeRecord: data %q", b)
239
+ return b , nil
240
+ }
241
+
242
+ // not a handshake record
243
+ return nil , errNotHandshakeRecord
244
+ }
245
+
246
+ func (c * fragmentedReader ) readHandshakeBytes (n int ) error {
247
+ for len (c .handshake ) < n {
248
+ b , err := c .readHandshakeRecord ()
249
+ if err != nil {
250
+ return err
251
+ }
252
+ c .handshake = append (c .handshake , b ... )
253
+ }
254
+ return nil
255
+ }
256
+
257
+ func (c * fragmentedReader ) appendHandshakeRecord (b []byte ) {
258
+ for len (b ) > 0 {
259
+ n := len (b )
260
+ if n > 65535 {
261
+ n = 65535
262
+ }
263
+ c .buffer = append (c .buffer , 0x16 , 0x03 , 0x01 , byte (n >> 8 ), byte (n ))
264
+ c .buffer = append (c .buffer , b [:n ]... )
265
+ b = b [n :]
266
+ }
267
+ }
268
+
269
+ func findSnameExt (b []byte ) (int , int ) {
270
+ pos := 0
271
+ for len (b ) >= 4 {
272
+ n := int (b [2 ])<< 8 | int (b [3 ])
273
+ if ! (4 + n <= len (b )) {
274
+ break
275
+ }
276
+ if b [0 ] == 0 && b [1 ] == 0 {
277
+ return pos + 4 , n
278
+ }
279
+ b = b [4 + n :]
280
+ pos += 4 + n
281
+ }
282
+ return - 1 , - 1
283
+ }
284
+
285
+ func (c * fragmentedReader ) processClientHello () error {
286
+ err := c .readHandshakeBytes (4 )
287
+ if err != nil {
288
+ log .Printf ("failed to read 4 bytes (client hello header): %s" , err )
289
+ if err == errNotHandshakeRecord {
290
+ err = nil
291
+ }
292
+ return err
293
+ }
294
+ if c .handshake [0 ] != 0x01 {
295
+ // expected client hello message
296
+ return nil
297
+ }
298
+ n := int (c .handshake [1 ])<< 16 | int (c .handshake [2 ])<< 8 | int (c .handshake [3 ])
299
+ err = c .readHandshakeBytes (4 + n )
300
+ if err != nil {
301
+ log .Printf ("failed to read %d bytes (client hello data): %s" , n , err )
302
+ if err == errNotHandshakeRecord {
303
+ err = nil
304
+ }
305
+ return err
306
+ }
307
+ pos := 4
308
+ end := 4 + n
309
+ if ! (pos + 2 <= end ) || c .handshake [pos ] != 0x03 || c .handshake [pos + 1 ] != 0x03 {
310
+ // expected TLS 1.2 outer layer
311
+ return nil
312
+ }
313
+ pos += 2 + 32
314
+ // skip session id
315
+ if ! (pos + 1 <= end ) {
316
+ return nil
317
+ }
318
+ k := int (c .handshake [pos ])
319
+ pos += 1 + k
320
+ // skip cipher suites
321
+ if ! (pos + 2 <= end ) {
322
+ return nil
323
+ }
324
+ k = int (c .handshake [pos ])<< 8 | int (c .handshake [pos + 1 ])
325
+ pos += 2 + k
326
+ // skip compression methods
327
+ if ! (pos + 1 <= end ) {
328
+ return nil
329
+ }
330
+ k = int (c .handshake [pos ])
331
+ pos += 1 + k
332
+ // extensions
333
+ if ! (pos + 2 <= end ) {
334
+ return nil
335
+ }
336
+ extSize := int (c .handshake [pos ])<< 8 | int (c .handshake [pos + 1 ])
337
+ if extSize < 4 {
338
+ return nil
339
+ }
340
+ pos += 2
341
+ extStart := pos
342
+ extEnd := extStart + extSize
343
+ if ! (extEnd <= end ) {
344
+ return nil
345
+ }
346
+ ext := c .handshake [extStart :extEnd ]
347
+
348
+ // log.Printf("Full handshake: %x", c.handshake)
349
+ // log.Printf("Found extensions: %q", ext)
350
+
351
+ snameStart , snameSize := findSnameExt (ext )
352
+ if snameStart >= 0 {
353
+ // we found a server name extension
354
+ // let's repackage it into small fragmented records
355
+ snameStart += extStart
356
+ snameEnd := snameStart + snameSize
357
+ log .Printf ("Fragmenting sname: %q" , c .handshake [snameStart :snameEnd ])
358
+ pos = snameStart - 3 // we want to fragment the 00 00 tag as well
359
+ // log.Printf("Original buffer: %q", c.buffer)
360
+ c .buffer = c .buffer [:0 ]
361
+ c .appendHandshakeRecord (c .handshake [:pos ])
362
+ for pos + 2 < snameEnd {
363
+ c .appendHandshakeRecord (c .handshake [pos : pos + 2 ])
364
+ pos += 2
365
+ }
366
+ c .appendHandshakeRecord (c .handshake [pos :])
367
+ // log.Printf("Final buffer: %q", c.buffer)
368
+ }
369
+
370
+ return nil
371
+ }
372
+
373
+ func (c * fragmentedReader ) Read (b []byte ) (int , error ) {
374
+ if ! c .processed {
375
+ err := c .processClientHello ()
376
+ if err != nil {
377
+ c .handshakeErr = err
378
+ }
379
+ c .processed = true
380
+ }
381
+ if len (c .buffer ) > 0 {
382
+ n := len (c .buffer )
383
+ if n > len (b ) {
384
+ n = len (b )
385
+ }
386
+ copy (b [:n ], c .buffer [:n ])
387
+ c .buffer = c .buffer [n :]
388
+ if len (c .buffer ) == 0 {
389
+ c .buffer = nil
390
+ }
391
+ return n , nil
392
+ }
393
+ if c .handshakeErr != nil {
394
+ return 0 , c .handshakeErr
395
+ }
396
+ n , err := c .Reader .Read (b )
397
+ return n , err
398
+ }
399
+
202
400
func (p * SecureReverseProxy ) ServeHTTP (rw http.ResponseWriter , req * http.Request ) {
203
401
if req .Host == "localhost" || strings .HasPrefix (req .Host , "localhost:" ) {
204
402
rw .WriteHeader (200 )
@@ -228,6 +426,9 @@ func (p *SecureReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request
228
426
if actions & actionDirect != 0 {
229
427
prefix = "(direct) "
230
428
}
429
+ if actions & actionFragment != 0 {
430
+ prefix += "(fragmented) "
431
+ }
231
432
suffix := ""
232
433
if actions & actionBlock != 0 {
233
434
suffix = " (blocked)"
@@ -281,11 +482,11 @@ func (p *SecureReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request
281
482
io .WriteString (b , "HTTP/1.1 200 OK\r \n \r \n " )
282
483
b .Flush () // this is the last write into b
283
484
284
- // write side runs it its own goroutine
485
+ // write side runs in its own goroutine
285
486
go func () {
286
487
defer local .Close ()
287
488
defer remote .Close ()
288
- var buffer [65536 ]byte
489
+ var buffer [1024 ]byte
289
490
done := false
290
491
for ! done {
291
492
n , err := remote .Read (buffer [:])
@@ -307,11 +508,18 @@ func (p *SecureReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request
307
508
}
308
509
}()
309
510
511
+ var r io.Reader = b
512
+ if actions & actionFragment != 0 {
513
+ r = & fragmentedReader {
514
+ Reader : r ,
515
+ }
516
+ }
517
+
310
518
// read side runs here, first we grab what we have in in b
311
- var buffer [65536 ]byte
519
+ var buffer [1024 ]byte
312
520
done := false
313
521
for ! done {
314
- n , err := b .Read (buffer [:])
522
+ n , err := r .Read (buffer [:])
315
523
if n > 0 {
316
524
_ , werr := remote .Write (buffer [:n ])
317
525
if werr != nil {
0 commit comments