Skip to content

Commit c87bbac

Browse files
committed
Refactoring connect for windows
1 parent 30263ee commit c87bbac

File tree

2 files changed

+50
-39
lines changed

2 files changed

+50
-39
lines changed

chronos/transports/stream.nim

+45-38
Original file line numberDiff line numberDiff line change
@@ -301,15 +301,10 @@ proc bindSocket*(sock: AsyncFD, localAddress: TransportAddress, reuseAddr = true
301301
# Setting SO_REUSEADDR option we are able to reuse ports using the 0.0.0.0 address (or equivalent)
302302
setSockOptInt(SocketHandle(sock), SOL_SOCKET, SO_REUSEADDR, 1)
303303

304-
var raddress =
305-
when defined(windows):
306-
windowsAnyAddressFix(localAddress)
307-
else:
308-
localAddress
309304
var
310305
localAddr: Sockaddr_storage
311306
localAddrLen: SockLen
312-
raddress.toSAddr(localAddr, localAddrLen)
307+
localAddress.toSAddr(localAddr, localAddrLen)
313308
if bindSocket(SocketHandle(sock), cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
314309
raiseTransportOsError(osLastError())
315310

@@ -718,6 +713,13 @@ elif defined(windows):
718713
sizeof(saddr).SockLen) != 0'i32:
719714
result = false
720715

716+
proc isDomainSet(sock: AsyncFD): bool =
717+
try:
718+
discard getSockDomain(SocketHandle(sock))
719+
true
720+
except CatchableError as ex:
721+
false
722+
721723
proc connect*(sock: AsyncFD,
722724
address: TransportAddress,
723725
bufferSize = DefaultStreamBufferSize,
@@ -734,7 +736,6 @@ elif defined(windows):
734736
var
735737
saddr: Sockaddr_storage
736738
slen: SockLen
737-
sock: AsyncFD
738739
povl: RefCustomOverlapped
739740

740741
var raddress = windowsAnyAddressFix(address)
@@ -745,7 +746,7 @@ elif defined(windows):
745746
retFuture.fail(getTransportOsError(osLastError()))
746747
return retFuture
747748

748-
if not(bindToDomain(sock, raddress.getDomain())):
749+
if not isDomainSet(sock) and not(bindToDomain(sock, raddress.getDomain())):
749750
let err = wsaGetLastError()
750751
sock.closeSocket()
751752
retFuture.fail(getTransportOsError(err))
@@ -791,6 +792,32 @@ elif defined(windows):
791792

792793
retFuture.cancelCallback = cancel
793794

795+
else: #address.family == AddressFamily.Unix:
796+
retFuture.fail(newException(TransportAddressError, "Unsupported address family"))
797+
798+
return retFuture
799+
800+
proc connect*(address: TransportAddress,
801+
bufferSize = DefaultStreamBufferSize,
802+
child: StreamTransport = nil,
803+
flags: set[TransportFlags] = {}): Future[StreamTransport] =
804+
## Open new connection to remote peer with address ``address`` and create
805+
## new transport object ``StreamTransport`` for established connection.
806+
## ``bufferSize`` is size of internal buffer for transport.
807+
var retFuture = newFuture[StreamTransport]("stream.transport.connect")
808+
if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
809+
var raddress = windowsAnyAddressFix(address)
810+
try:
811+
let sock = createAsyncSocket(raddress.getDomain(), SockType.SOCK_STREAM, Protocol.IPPROTO_TCP)
812+
let r = connect(sock, address, bufferSize, child, flags)
813+
proc cb(arg: pointer) =
814+
try:
815+
retFuture.complete(r.read)
816+
except CatchableError as exc:
817+
retFuture.fail(exc)
818+
r.addCallback(cb)
819+
except CatchableError as exc:
820+
retFuture.fail(exc)
794821
elif address.family == AddressFamily.Unix:
795822
## Unix domain socket emulation with Windows Named Pipes.
796823
var pipeHandle = INVALID_HANDLE_VALUE
@@ -807,45 +834,25 @@ elif defined(windows):
807834
if pipeHandle == INVALID_HANDLE_VALUE:
808835
let err = osLastError()
809836
if int32(err) == ERROR_PIPE_BUSY:
810-
discard setTimer(Moment.fromNow(50.milliseconds),
811-
pipeContinuation, nil)
837+
discard setTimer(Moment.fromNow(50.milliseconds), pipeContinuation, nil)
812838
else:
813839
retFuture.fail(getTransportOsError(err))
814840
else:
815-
try:
816-
register(AsyncFD(pipeHandle))
817-
except CatchableError as exc:
818-
retFuture.fail(exc)
819-
return
820-
821-
let transp = try: newStreamPipeTransport(AsyncFD(pipeHandle),
822-
bufferSize, child)
823-
except CatchableError as exc:
824-
retFuture.fail(exc)
825-
return
841+
let transp =
842+
try:
843+
register(AsyncFD(pipeHandle))
844+
newStreamPipeTransport(AsyncFD(pipeHandle), bufferSize, child)
845+
except CatchableError as exc:
846+
retFuture.fail(exc)
847+
return
826848
# Start tracking transport
827849
trackStream(transp)
828850
retFuture.complete(transp)
829851
pipeContinuation(nil)
830-
852+
else:
853+
retFuture.fail(newException(TransportAddressError, "Unsupported address family"))
831854
return retFuture
832855

833-
proc connect*(address: TransportAddress,
834-
bufferSize = DefaultStreamBufferSize,
835-
child: StreamTransport = nil,
836-
flags: set[TransportFlags] = {}): Future[StreamTransport] =
837-
## Open new connection to remote peer with address ``address`` and create
838-
## new transport object ``StreamTransport`` for established connection.
839-
## ``bufferSize`` is size of internal buffer for transport.
840-
var raddress = windowsAnyAddressFix(address)
841-
let sock =
842-
try: createAsyncSocket(raddress.getDomain(), SockType.SOCK_STREAM, Protocol.IPPROTO_TCP)
843-
except CatchableError as exc:
844-
var retFuture = newFuture[StreamTransport]("stream.transport.connect")
845-
retFuture.fail(exc)
846-
return retFuture
847-
return connect(sock, raddress, bufferSize, child, flags)
848-
849856
proc createAcceptPipe(server: StreamServer) {.
850857
raises: [Defect, CatchableError].} =
851858
let pipeSuffix = $cast[cstring](addr server.local.address_un)

tests/teststream.nim

+5-1
Original file line numberDiff line numberDiff line change
@@ -1389,13 +1389,17 @@ suite "Stream Transport test suite":
13891389
expect TransportOsError:
13901390
bindSocket(sock3, ta, false)
13911391

1392+
waitFor transp1.closeWait()
1393+
waitFor transp2.closeWait()
1394+
sock3.closeSocket()
1395+
13921396
for server in servers:
13931397
server.stop()
13941398
waitFor server.closeWait()
13951399

13961400
waitFor transp1.closeWait()
13971401
waitFor transp2.closeWait()
1398-
sock2.closeSocket()
1402+
sock3.closeSocket()
13991403

14001404
test "Leaks test":
14011405
proc getTrackerLeaks(tracker: string): bool =

0 commit comments

Comments
 (0)