@@ -53,7 +53,7 @@ def connect(host="localhost", user=None, password="",
53
53
connect_timeout = None , read_default_group = None ,
54
54
autocommit = False , echo = False ,
55
55
local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
56
- program_name = '' , server_public_key = None ):
56
+ program_name = '' , server_public_key = None , implicit_tls = False ):
57
57
"""See connections.Connection.__init__() for information about
58
58
defaults."""
59
59
coro = _connect (host = host , user = user , password = password , db = db ,
@@ -66,7 +66,8 @@ def connect(host="localhost", user=None, password="",
66
66
read_default_group = read_default_group ,
67
67
autocommit = autocommit , echo = echo ,
68
68
local_infile = local_infile , loop = loop , ssl = ssl ,
69
- auth_plugin = auth_plugin , program_name = program_name )
69
+ auth_plugin = auth_plugin , program_name = program_name ,
70
+ implicit_tls = implicit_tls )
70
71
return _ConnectionContextManager (coro )
71
72
72
73
@@ -142,7 +143,7 @@ def __init__(self, host="localhost", user=None, password="",
142
143
connect_timeout = None , read_default_group = None ,
143
144
autocommit = False , echo = False ,
144
145
local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
145
- program_name = '' , server_public_key = None ):
146
+ program_name = '' , server_public_key = None , implicit_tls = False ):
146
147
"""
147
148
Establish a connection to the MySQL database. Accepts several
148
149
arguments:
@@ -184,6 +185,9 @@ def __init__(self, host="localhost", user=None, password="",
184
185
handshaking with MySQL. (omitted by default)
185
186
:param server_public_key: SHA256 authentication plugin public
186
187
key value.
188
+ :param implicit_tls: Establish TLS immediately, skipping non-TLS
189
+ preamble before upgrading to TLS.
190
+ (default: False)
187
191
:param loop: asyncio loop
188
192
"""
189
193
self ._loop = loop or asyncio .get_event_loop ()
@@ -218,6 +222,7 @@ def __init__(self, host="localhost", user=None, password="",
218
222
self ._auth_plugin_used = ""
219
223
self ._secure = False
220
224
self .server_public_key = server_public_key
225
+ self ._implicit_tls = implicit_tls
221
226
self .salt = None
222
227
223
228
from . import __version__
@@ -241,7 +246,7 @@ def __init__(self, host="localhost", user=None, password="",
241
246
self .use_unicode = use_unicode
242
247
243
248
self ._ssl_context = ssl
244
- if ssl :
249
+ if ssl and not implicit_tls :
245
250
client_flag |= CLIENT .SSL
246
251
247
252
self ._encoding = charset_by_name (self ._charset ).encoding
@@ -536,7 +541,8 @@ async def _connect(self):
536
541
537
542
self ._next_seq_id = 0
538
543
539
- await self ._get_server_information ()
544
+ if not self ._implicit_tls :
545
+ await self ._get_server_information ()
540
546
await self ._request_authentication ()
541
547
542
548
self .connected_time = self ._loop .time ()
@@ -727,7 +733,8 @@ async def _execute_command(self, command, sql):
727
733
728
734
async def _request_authentication (self ):
729
735
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
730
- if int (self .server_version .split ('.' , 1 )[0 ]) >= 5 :
736
+ # FIXME: change this before merge
737
+ if self ._implicit_tls or int (self .server_version .split ('.' , 1 )[0 ]) >= 5 :
731
738
self .client_flag |= CLIENT .MULTI_RESULTS
732
739
733
740
if self .user is None :
@@ -737,8 +744,10 @@ async def _request_authentication(self):
737
744
data_init = struct .pack ('<iIB23s' , self .client_flag , MAX_PACKET_LEN ,
738
745
charset_id , b'' )
739
746
740
- if self ._ssl_context and self .server_capabilities & CLIENT .SSL :
741
- self .write_packet (data_init )
747
+ if self ._ssl_context and \
748
+ (self ._implicit_tls or self .server_capabilities & CLIENT .SSL ):
749
+ if not self ._implicit_tls :
750
+ self .write_packet (data_init )
742
751
743
752
# Stop sending events to data_received
744
753
self ._writer .transport .pause_reading ()
@@ -760,6 +769,9 @@ async def _request_authentication(self):
760
769
server_hostname = self ._host
761
770
)
762
771
772
+ if self ._implicit_tls :
773
+ await self ._get_server_information ()
774
+
763
775
self ._secure = True
764
776
765
777
if isinstance (self .user , str ):
0 commit comments