Skip to content

Commit 932b71a

Browse files
mengelbartSean-Der
authored andcommitted
Implement draft-ietf-tsvwg-sctp-zero-checksum-01
1 parent 2927025 commit 932b71a

10 files changed

+325
-71
lines changed

association.go

+73-14
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ type Association struct {
177177
cumulativeTSNAckPoint uint32
178178
advancedPeerTSNAckPoint uint32
179179
useForwardTSN bool
180+
useZeroChecksum bool
181+
requestZeroChecksum bool
180182

181183
// Congestion control parameters
182184
maxReceiveBufferSize uint32
@@ -233,6 +235,7 @@ type Config struct {
233235
NetConn net.Conn
234236
MaxReceiveBufferSize uint32
235237
MaxMessageSize uint32
238+
EnableZeroChecksum bool
236239
LoggerFactory logging.LoggerFactory
237240
}
238241

@@ -320,6 +323,7 @@ func createAssociation(config Config) *Association {
320323
handshakeCompletedCh: make(chan error),
321324
cumulativeTSNAckPoint: tsn - 1,
322325
advancedPeerTSNAckPoint: tsn - 1,
326+
requestZeroChecksum: config.EnableZeroChecksum,
323327
silentError: ErrSilentlyDiscard,
324328
stats: &associationStats{},
325329
log: config.LoggerFactory.NewLogger("sctp"),
@@ -362,6 +366,11 @@ func (a *Association) init(isClient bool) {
362366
init.initiateTag = a.myVerificationTag
363367
init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize
364368
setSupportedExtensions(&init.chunkInitCommon)
369+
370+
if a.requestZeroChecksum {
371+
init.params = append(init.params, &paramZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod})
372+
}
373+
365374
a.storedInit = init
366375

367376
err := a.sendInit()
@@ -618,10 +627,45 @@ func (a *Association) unregisterStream(s *Stream, err error) {
618627
s.readNotifier.Broadcast()
619628
}
620629

630+
func chunkMandatoryChecksum(cc []chunk) bool {
631+
for _, c := range cc {
632+
switch c.(type) {
633+
case *chunkInit, *chunkInitAck, *chunkCookieEcho:
634+
return true
635+
}
636+
}
637+
return false
638+
}
639+
640+
func (a *Association) marshalPacket(p *packet) ([]byte, error) {
641+
return p.marshal(!a.useZeroChecksum || chunkMandatoryChecksum(p.chunks))
642+
}
643+
644+
func (a *Association) unmarshalPacket(raw []byte) (*packet, error) {
645+
p := &packet{}
646+
if !a.useZeroChecksum {
647+
if err := p.unmarshal(true, raw); err != nil {
648+
return nil, err
649+
}
650+
return p, nil
651+
}
652+
653+
if err := p.unmarshal(false, raw); err != nil {
654+
return nil, err
655+
}
656+
if chunkMandatoryChecksum(p.chunks) {
657+
if err := p.unmarshal(true, raw); err != nil {
658+
return nil, err
659+
}
660+
}
661+
662+
return p, nil
663+
}
664+
621665
// handleInbound parses incoming raw packets
622666
func (a *Association) handleInbound(raw []byte) error {
623-
p := &packet{}
624-
if err := p.unmarshal(raw); err != nil {
667+
p, err := a.unmarshalPacket(raw)
668+
if err != nil {
625669
a.log.Warnf("[%s] unable to parse SCTP packet %s", a.name, err)
626670
return nil
627671
}
@@ -647,7 +691,7 @@ func (a *Association) handleInbound(raw []byte) error {
647691
// The caller should hold the lock
648692
func (a *Association) gatherDataPacketsToRetransmit(rawPackets [][]byte) [][]byte {
649693
for _, p := range a.getDataPacketsToRetransmit() {
650-
raw, err := p.marshal()
694+
raw, err := a.marshalPacket(p)
651695
if err != nil {
652696
a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name)
653697
continue
@@ -668,7 +712,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
668712
a.log.Tracef("[%s] T3-rtx timer start (pt1)", a.name)
669713
a.t3RTX.start(a.rtoMgr.getRTO())
670714
for _, p := range a.bundleDataChunksIntoPackets(chunks) {
671-
raw, err := p.marshal()
715+
raw, err := a.marshalPacket(p)
672716
if err != nil {
673717
a.log.Warnf("[%s] failed to serialize a DATA packet", a.name)
674718
continue
@@ -683,7 +727,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
683727
a.log.Debugf("[%s] retransmit %d RECONFIG chunk(s)", a.name, len(a.reconfigs))
684728
for _, c := range a.reconfigs {
685729
p := a.createPacket([]chunk{c})
686-
raw, err := p.marshal()
730+
raw, err := a.marshalPacket(p)
687731
if err != nil {
688732
a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be retransmitted", a.name)
689733
} else {
@@ -706,7 +750,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
706750
a.log.Debugf("[%s] sending RECONFIG: rsn=%d tsn=%d streams=%v",
707751
a.name, rsn, a.myNextTSN-1, sisToReset)
708752
p := a.createPacket([]chunk{c})
709-
raw, err := p.marshal()
753+
raw, err := a.marshalPacket(p)
710754
if err != nil {
711755
a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be transmitted", a.name)
712756
} else {
@@ -769,7 +813,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
769813
}
770814

771815
if len(toFastRetrans) > 0 {
772-
raw, err := a.createPacket(toFastRetrans).marshal()
816+
raw, err := a.marshalPacket(a.createPacket(toFastRetrans))
773817
if err != nil {
774818
a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
775819
} else {
@@ -787,7 +831,7 @@ func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte {
787831
a.ackState = ackStateIdle
788832
sack := a.createSelectiveAckChunk()
789833
a.log.Debugf("[%s] sending SACK: %s", a.name, sack)
790-
raw, err := a.createPacket([]chunk{sack}).marshal()
834+
raw, err := a.marshalPacket(a.createPacket([]chunk{sack}))
791835
if err != nil {
792836
a.log.Warnf("[%s] failed to serialize a SACK packet", a.name)
793837
} else {
@@ -804,7 +848,7 @@ func (a *Association) gatherOutboundForwardTSNPackets(rawPackets [][]byte) [][]b
804848
a.willSendForwardTSN = false
805849
if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) {
806850
fwdtsn := a.createForwardTSN()
807-
raw, err := a.createPacket([]chunk{fwdtsn}).marshal()
851+
raw, err := a.marshalPacket(a.createPacket([]chunk{fwdtsn}))
808852
if err != nil {
809853
a.log.Warnf("[%s] failed to serialize a Forward TSN packet", a.name)
810854
} else {
@@ -827,7 +871,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by
827871
cumulativeTSNAck: a.cumulativeTSNAckPoint,
828872
}
829873

830-
raw, err := a.createPacket([]chunk{shutdown}).marshal()
874+
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdown}))
831875
if err != nil {
832876
a.log.Warnf("[%s] failed to serialize a Shutdown packet", a.name)
833877
} else {
@@ -839,7 +883,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by
839883

840884
shutdownAck := &chunkShutdownAck{}
841885

842-
raw, err := a.createPacket([]chunk{shutdownAck}).marshal()
886+
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownAck}))
843887
if err != nil {
844888
a.log.Warnf("[%s] failed to serialize a ShutdownAck packet", a.name)
845889
} else {
@@ -851,7 +895,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by
851895

852896
shutdownComplete := &chunkShutdownComplete{}
853897

854-
raw, err := a.createPacket([]chunk{shutdownComplete}).marshal()
898+
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownComplete}))
855899
if err != nil {
856900
a.log.Warnf("[%s] failed to serialize a ShutdownComplete packet", a.name)
857901
} else {
@@ -875,7 +919,7 @@ func (a *Association) gatherAbortPacket() ([]byte, error) {
875919
abort.errorCauses = []errorCause{cause}
876920
}
877921

878-
raw, err := a.createPacket([]chunk{abort}).marshal()
922+
raw, err := a.marshalPacket(a.createPacket([]chunk{abort}))
879923

880924
return raw, err
881925
}
@@ -900,7 +944,7 @@ func (a *Association) gatherOutbound() ([][]byte, bool) {
900944

901945
if a.controlQueue.size() > 0 {
902946
for _, p := range a.controlQueue.popAll() {
903-
raw, err := p.marshal()
947+
raw, err := a.marshalPacket(p)
904948
if err != nil {
905949
a.log.Warnf("[%s] failed to serialize a control packet", a.name)
906950
continue
@@ -1092,6 +1136,7 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
10921136
// subtracting one from it.
10931137
a.peerLastTSN = i.initialTSN - 1
10941138

1139+
peerHasZeroChecksum := false
10951140
for _, param := range i.params {
10961141
switch v := param.(type) { // nolint:gocritic
10971142
case *paramSupportedExtensions:
@@ -1101,8 +1146,11 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
11011146
a.useForwardTSN = true
11021147
}
11031148
}
1149+
case *paramZeroChecksumAcceptable:
1150+
peerHasZeroChecksum = v.edmid == dtlsErrorDetectionMethod
11041151
}
11051152
}
1153+
11061154
if !a.useForwardTSN {
11071155
a.log.Warnf("[%s] not using ForwardTSN (on init)", a.name)
11081156
}
@@ -1129,6 +1177,12 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
11291177

11301178
initAck.params = []param{a.myCookie}
11311179

1180+
if peerHasZeroChecksum {
1181+
initAck.params = append(initAck.params, &paramZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod})
1182+
a.useZeroChecksum = true
1183+
}
1184+
a.log.Debugf("[%s] useZeroChecksum=%t (on init)", a.name, a.useZeroChecksum)
1185+
11321186
setSupportedExtensions(&initAck.chunkInitCommon)
11331187

11341188
outbound.chunks = []chunk{initAck}
@@ -1186,8 +1240,13 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error {
11861240
a.useForwardTSN = true
11871241
}
11881242
}
1243+
case *paramZeroChecksumAcceptable:
1244+
a.useZeroChecksum = v.edmid == dtlsErrorDetectionMethod
11891245
}
11901246
}
1247+
1248+
a.log.Debugf("[%s] useZeroChecksum=%t (on initAck)", a.name, a.useZeroChecksum)
1249+
11911250
if !a.useForwardTSN {
11921251
a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name)
11931252
}

association_test.go

+87-3
Original file line numberDiff line numberDiff line change
@@ -1622,7 +1622,7 @@ func TestAssocT1CookieTimer(t *testing.T) {
16221622
// Drop all COOKIE-ECHO
16231623
br.Filter(0, func(raw []byte) bool {
16241624
p := &packet{}
1625-
err := p.unmarshal(raw)
1625+
err := p.unmarshal(true, raw)
16261626
if !assert.Nil(t, err, "failed to parse packet") {
16271627
return false // drop
16281628
}
@@ -2285,7 +2285,7 @@ func TestAssocAbort(t *testing.T) {
22852285
errorCauseHeader: errorCauseHeader{code: protocolViolation},
22862286
}},
22872287
}
2288-
packet, err := a0.createPacket([]chunk{abort}).marshal()
2288+
packet, err := a0.marshalPacket(a0.createPacket([]chunk{abort}))
22892289
assert.NoError(t, err)
22902290

22912291
_, _, err = establishSessionPair(br, a0, a1, si)
@@ -2964,7 +2964,7 @@ func TestAssociation_HandlePacketInCookieWaitState(t *testing.T) {
29642964
}()
29652965
}
29662966

2967-
packet, err := testCase.inputPacket.marshal()
2967+
packet, err := a.marshalPacket(testCase.inputPacket)
29682968
assert.NoError(t, err)
29692969
_, err = charlieConn.Write(packet)
29702970
assert.NoError(t, err)
@@ -3072,3 +3072,87 @@ loop:
30723072
assert.Error(t, err1, "context canceled")
30733073
assert.Error(t, err2, "context canceled")
30743074
}
3075+
3076+
type customLogger struct {
3077+
expectZeroChecksum bool
3078+
t *testing.T
3079+
}
3080+
3081+
func (c customLogger) Trace(string) {}
3082+
func (c customLogger) Tracef(string, ...interface{}) {}
3083+
func (c customLogger) Debug(string) {}
3084+
func (c customLogger) Debugf(format string, args ...interface{}) {
3085+
if format == "[%s] useZeroChecksum=%t (on initAck)" {
3086+
assert.Equal(c.t, args[1], c.expectZeroChecksum)
3087+
}
3088+
}
3089+
func (c customLogger) Info(string) {}
3090+
func (c customLogger) Infof(string, ...interface{}) {}
3091+
func (c customLogger) Warn(string) {}
3092+
func (c customLogger) Warnf(string, ...interface{}) {}
3093+
func (c customLogger) Error(string) {}
3094+
func (c customLogger) Errorf(string, ...interface{}) {}
3095+
3096+
func (c customLogger) NewLogger(string) logging.LeveledLogger {
3097+
return c
3098+
}
3099+
3100+
func TestAssociation_ZeroChecksum(t *testing.T) {
3101+
checkGoroutineLeaks(t)
3102+
3103+
lim := test.TimeOut(time.Second * 10)
3104+
defer lim.Stop()
3105+
3106+
for _, testCase := range []struct {
3107+
clientZeroChecksum, serverZeroChecksum, expectChecksumEnabled bool
3108+
}{
3109+
{true, true, true},
3110+
{false, false, false},
3111+
{true, false, true},
3112+
{false, true, false},
3113+
} {
3114+
a1chan, a2chan := make(chan *Association), make(chan *Association)
3115+
3116+
udp1, udp2 := createUDPConnPair()
3117+
3118+
go func() {
3119+
a1, err := Client(Config{
3120+
NetConn: udp1,
3121+
LoggerFactory: &customLogger{testCase.expectChecksumEnabled, t},
3122+
EnableZeroChecksum: testCase.clientZeroChecksum,
3123+
})
3124+
assert.NoError(t, err)
3125+
a1chan <- a1
3126+
}()
3127+
3128+
go func() {
3129+
a2, err := Server(Config{
3130+
NetConn: udp2,
3131+
LoggerFactory: &customLogger{testCase.expectChecksumEnabled, t},
3132+
EnableZeroChecksum: testCase.serverZeroChecksum,
3133+
})
3134+
assert.NoError(t, err)
3135+
a2chan <- a2
3136+
}()
3137+
3138+
a1, a2 := <-a1chan, <-a2chan
3139+
3140+
writeStream, err := a1.OpenStream(1, PayloadTypeWebRTCString)
3141+
require.NoError(t, err)
3142+
3143+
readStream, err := a2.OpenStream(1, PayloadTypeWebRTCString)
3144+
require.NoError(t, err)
3145+
3146+
testData := []byte("test")
3147+
_, err = writeStream.Write(testData)
3148+
require.NoError(t, err)
3149+
3150+
buf := make([]byte, len(testData))
3151+
_, err = readStream.Read(buf)
3152+
assert.NoError(t, err)
3153+
assert.Equal(t, testData, buf)
3154+
3155+
require.NoError(t, a1.Close())
3156+
require.NoError(t, a2.Close())
3157+
}
3158+
}

0 commit comments

Comments
 (0)