Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-44011: New asyncio ssl implementation #31275

Merged
merged 3 commits into from
Feb 15, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
@@ -269,7 +269,7 @@ async def restore(self):
class Server(events.AbstractServer):

def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
ssl_handshake_timeout):
ssl_handshake_timeout, ssl_shutdown_timeout=None):
self._loop = loop
self._sockets = sockets
self._active_count = 0
@@ -278,6 +278,7 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
self._backlog = backlog
self._ssl_context = ssl_context
self._ssl_handshake_timeout = ssl_handshake_timeout
self._ssl_shutdown_timeout = ssl_shutdown_timeout
self._serving = False
self._serving_forever_fut = None

@@ -309,7 +310,8 @@ def _start_serving(self):
sock.listen(self._backlog)
self._loop._start_serving(
self._protocol_factory, sock, self._ssl_context,
self, self._backlog, self._ssl_handshake_timeout)
self, self._backlog, self._ssl_handshake_timeout,
self._ssl_shutdown_timeout)

def get_loop(self):
return self._loop
@@ -463,6 +465,7 @@ def _make_ssl_transport(
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
call_connection_made=True):
"""Create SSL transport."""
raise NotImplementedError
@@ -965,6 +968,7 @@ async def create_connection(
proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
"""Connect to a TCP server.
@@ -1000,6 +1004,10 @@ async def create_connection(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')

if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family
interleave = 1
@@ -1075,7 +1083,8 @@ async def create_connection(

transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
@@ -1087,7 +1096,8 @@ async def create_connection(
async def _create_connection_transport(
self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):

sock.setblocking(False)

@@ -1098,7 +1108,8 @@ async def _create_connection_transport(
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
transport = self._make_socket_transport(sock, protocol, waiter)

@@ -1189,7 +1200,8 @@ async def _sendfile_fallback(self, transp, file, offset, count):
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
"""Upgrade transport to TLS.
Return a new transport that *protocol* should start using
@@ -1212,6 +1224,7 @@ async def start_tls(self, transport, protocol, sslcontext, *,
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False)

# Pause early so that "ssl_protocol.data_received()" doesn't
@@ -1397,6 +1410,7 @@ async def create_server(
reuse_address=None,
reuse_port=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""Create a TCP server.
@@ -1420,6 +1434,10 @@ async def create_server(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')

if ssl_shutdown_timeout is not None and ssl is None:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

if host is not None or port is not None:
if sock is not None:
raise ValueError(
@@ -1492,7 +1510,8 @@ async def create_server(
sock.setblocking(False)

server = Server(self, sockets, protocol_factory,
ssl, backlog, ssl_handshake_timeout)
ssl, backlog, ssl_handshake_timeout,
ssl_shutdown_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
@@ -1506,17 +1525,23 @@ async def create_server(
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
if sock.type != socket.SOCK_STREAM:
raise ValueError(f'A Stream Socket was expected, got {sock!r}')

if ssl_handshake_timeout is not None and not ssl:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')

if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
7 changes: 7 additions & 0 deletions Lib/asyncio/constants.py
Original file line number Diff line number Diff line change
@@ -15,10 +15,17 @@
# The default timeout matches that of Nginx.
SSL_HANDSHAKE_TIMEOUT = 60.0

# Number of seconds to wait for SSL shutdown to complete
# The default timeout mimics lingering_time
SSL_SHUTDOWN_TIMEOUT = 30.0

# Used in sendfile fallback code. We use fallback for platforms
# that don't support sendfile, or for TLS connections.
SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256

FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB
FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB

# The enum should be here to break circular dependencies between
# base_events and sslproto
class _SendfileMode(enum.Enum):
19 changes: 16 additions & 3 deletions Lib/asyncio/events.py
Original file line number Diff line number Diff line change
@@ -303,6 +303,7 @@ async def create_connection(
flags=0, sock=None, local_addr=None,
server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
raise NotImplementedError

@@ -312,6 +313,7 @@ async def create_server(
flags=socket.AI_PASSIVE, sock=None, backlog=100,
ssl=None, reuse_address=None, reuse_port=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a TCP server bound to host and port.
@@ -352,6 +354,10 @@ async def create_server(
will wait for completion of the SSL handshake before aborting the
connection. Default is 60s.
ssl_shutdown_timeout is the time in seconds that an SSL server
will wait for completion of the SSL shutdown procedure
before aborting the connection. Default is 30s.
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
@@ -370,7 +376,8 @@ async def sendfile(self, transport, file, offset=0, count=None,
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
"""Upgrade a transport to TLS.
Return a new transport that *protocol* should start using
@@ -382,13 +389,15 @@ async def create_unix_connection(
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
raise NotImplementedError

async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a UNIX Domain Socket server.
@@ -410,6 +419,9 @@ async def create_unix_server(
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for the SSL handshake to complete (defaults to 60s).
ssl_shutdown_timeout is the time in seconds that an SSL server
will wait for the SSL shutdown to finish (defaults to 30s).
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
@@ -420,7 +432,8 @@ async def create_unix_server(
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
"""Handle an accepted connection.
This is used by servers that accept connections outside of
12 changes: 8 additions & 4 deletions Lib/asyncio/proactor_events.py
Original file line number Diff line number Diff line change
@@ -642,11 +642,13 @@ def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
_ProactorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
@@ -812,7 +814,8 @@ def _write_to_self(self):

def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):

def loop(f=None):
try:
@@ -826,7 +829,8 @@ def loop(f=None):
self._make_ssl_transport(
conn, protocol, sslcontext, server_side=True,
extra={'peername': addr}, server=server,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
self._make_socket_transport(
conn, protocol,
31 changes: 20 additions & 11 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
@@ -70,11 +70,15 @@ def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout
)
_SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
@@ -146,15 +150,17 @@ def _write_to_self(self):

def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog,
ssl_handshake_timeout)
ssl_handshake_timeout, ssl_shutdown_timeout)

def _accept_connection(
self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
# This method is only called once for each event loop tick where the
# listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop.
@@ -185,20 +191,22 @@ def _accept_connection(
self.call_later(constants.ACCEPT_RETRY_DELAY,
self._start_serving,
protocol_factory, sock, sslcontext, server,
backlog, ssl_handshake_timeout)
backlog, ssl_handshake_timeout,
ssl_shutdown_timeout)
else:
raise # The event loop will catch, log and ignore it.
else:
extra = {'peername': addr}
accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server,
ssl_handshake_timeout)
ssl_handshake_timeout, ssl_shutdown_timeout)
self.create_task(accept)

async def _accept_connection2(
self, protocol_factory, conn, extra,
sslcontext=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
protocol = None
transport = None
try:
@@ -208,7 +216,8 @@ async def _accept_connection2(
transport = self._make_ssl_transport(
conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra,
1,059 changes: 622 additions & 437 deletions Lib/asyncio/sslproto.py

Large diffs are not rendered by default.

17 changes: 14 additions & 3 deletions Lib/asyncio/unix_events.py
Original file line number Diff line number Diff line change
@@ -229,7 +229,8 @@ async def create_unix_connection(
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
assert server_hostname is None or isinstance(server_hostname, str)
if ssl:
if server_hostname is None:
@@ -241,6 +242,9 @@ async def create_unix_connection(
if ssl_handshake_timeout is not None:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
if ssl_shutdown_timeout is not None:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

if path is not None:
if sock is not None:
@@ -267,13 +271,15 @@ async def create_unix_connection(

transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
return transport, protocol

async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
@@ -282,6 +288,10 @@ async def create_unix_server(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')

if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

if path is not None:
if sock is not None:
raise ValueError(
@@ -328,7 +338,8 @@ async def create_unix_server(

sock.setblocking(False)
server = base_events.Server(self, [sock], protocol_factory,
ssl, backlog, ssl_handshake_timeout)
ssl, backlog, ssl_handshake_timeout,
ssl_shutdown_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
21 changes: 14 additions & 7 deletions Lib/test/test_asyncio/test_base_events.py
Original file line number Diff line number Diff line change
@@ -1451,44 +1451,51 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter,
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
ANY = mock.ANY
handshake_timeout = object()
shutdown_timeout = object()
# First try the default server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='python.org',
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)
# Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)
# Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='',
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='',
ssl_handshake_timeout=handshake_timeout)
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)

def test_create_connection_no_ssl_server_hostname_errors(self):
# When not using ssl, server_hostname must be None.
@@ -1869,7 +1876,7 @@ def test_accept_connection_exception(self, m_log):
constants.ACCEPT_RETRY_DELAY,
# self.loop._start_serving
mock.ANY,
MyProto, sock, None, None, mock.ANY, mock.ANY)
MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY)

def test_call_coroutine(self):
async def simple_coroutine():
38 changes: 0 additions & 38 deletions Lib/test/test_asyncio/test_selector_events.py
Original file line number Diff line number Diff line change
@@ -71,44 +71,6 @@ def test_make_socket_transport(self):

close_transport(transport)

@unittest.skipIf(ssl is None, 'No ssl module')
def test_make_ssl_transport(self):
m = mock.Mock()
self.loop._add_reader = mock.Mock()
self.loop._add_reader._is_coroutine = False
self.loop._add_writer = mock.Mock()
self.loop._remove_reader = mock.Mock()
self.loop._remove_writer = mock.Mock()
waiter = self.loop.create_future()
with test_utils.disable_logger():
transport = self.loop._make_ssl_transport(
m, asyncio.Protocol(), m, waiter)

with self.assertRaisesRegex(RuntimeError,
r'SSL transport.*not.*initialized'):
transport.is_reading()

# execute the handshake while the logger is disabled
# to ignore SSL handshake failure
test_utils.run_briefly(self.loop)

self.assertTrue(transport.is_reading())
transport.pause_reading()
transport.pause_reading()
self.assertFalse(transport.is_reading())
transport.resume_reading()
transport.resume_reading()
self.assertTrue(transport.is_reading())

# Sanity check
class_name = transport.__class__.__name__
self.assertIn("ssl", class_name.lower())
self.assertIn("transport", class_name.lower())

transport.close()
# execute pending callbacks to close the socket transport
test_utils.run_briefly(self.loop)

@mock.patch('asyncio.selector_events.ssl', None)
@mock.patch('asyncio.sslproto.ssl', None)
def test_make_ssl_transport_without_ssl_error(self):
1,721 changes: 1,721 additions & 0 deletions Lib/test/test_asyncio/test_ssl.py

Large diffs are not rendered by default.

35 changes: 20 additions & 15 deletions Lib/test/test_asyncio/test_sslproto.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,6 @@
from asyncio import log
from asyncio import protocols
from asyncio import sslproto
from test import support
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests

@@ -44,16 +43,13 @@ def ssl_protocol(self, *, waiter=None, proto=None):

def connection_made(self, ssl_proto, *, do_handshake=None):
transport = mock.Mock()
sslpipe = mock.Mock()
sslpipe.shutdown.return_value = b''
if do_handshake:
sslpipe.do_handshake.side_effect = do_handshake
else:
def mock_handshake(callback):
return []
sslpipe.do_handshake.side_effect = mock_handshake
with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
ssl_proto.connection_made(transport)
sslobj = mock.Mock()
# emulate reading decompressed data
sslobj.read.side_effect = ssl.SSLWantReadError
if do_handshake is not None:
sslobj.do_handshake = do_handshake
ssl_proto._sslobj = sslobj
ssl_proto.connection_made(transport)
return transport

def test_handshake_timeout_zero(self):
@@ -75,7 +71,10 @@ def test_handshake_timeout_negative(self):
def test_eof_received_waiter(self):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
self.connection_made(ssl_proto)
self.connection_made(
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
ssl_proto.eof_received()
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionResetError)
@@ -100,7 +99,10 @@ def test_connection_lost(self):
# yield from waiter hang if lost_connection was called.
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
self.connection_made(ssl_proto)
self.connection_made(
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
ssl_proto.connection_lost(ConnectionAbortedError)
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
@@ -110,7 +112,10 @@ def test_close_during_handshake(self):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)

transport = self.connection_made(ssl_proto)
transport = self.connection_made(
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
test_utils.run_briefly(self.loop)

ssl_proto._app_transport.close()
@@ -143,7 +148,7 @@ def test_data_received_after_closing(self):
transp.close()

# should not raise
self.assertIsNone(ssl_proto.data_received(b'data'))
self.assertIsNone(ssl_proto.buffer_updated(5))

def test_write_after_closing(self):
ssl_proto = self.ssl_protocol()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Reimplement SSL/TLS support in asyncio, borrow the implementation from
uvloop library.