9
9
import warnings
10
10
import configparser
11
11
import getpass
12
+ import ssl as ssllib
12
13
from functools import partial
13
14
14
15
from pymysql .charset import charset_by_name , charset_by_id
@@ -53,7 +54,7 @@ def connect(host="localhost", user=None, password="",
53
54
connect_timeout = None , read_default_group = None ,
54
55
autocommit = False , echo = False ,
55
56
local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
56
- program_name = '' , server_public_key = None ):
57
+ program_name = '' , server_public_key = None , implicit_tls = False ):
57
58
"""See connections.Connection.__init__() for information about
58
59
defaults."""
59
60
coro = _connect (host = host , user = user , password = password , db = db ,
@@ -66,7 +67,8 @@ def connect(host="localhost", user=None, password="",
66
67
read_default_group = read_default_group ,
67
68
autocommit = autocommit , echo = echo ,
68
69
local_infile = local_infile , loop = loop , ssl = ssl ,
69
- auth_plugin = auth_plugin , program_name = program_name )
70
+ auth_plugin = auth_plugin , program_name = program_name ,
71
+ implicit_tls = implicit_tls )
70
72
return _ConnectionContextManager (coro )
71
73
72
74
@@ -142,7 +144,7 @@ def __init__(self, host="localhost", user=None, password="",
142
144
connect_timeout = None , read_default_group = None ,
143
145
autocommit = False , echo = False ,
144
146
local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
145
- program_name = '' , server_public_key = None ):
147
+ program_name = '' , server_public_key = None , implicit_tls = False ):
146
148
"""
147
149
Establish a connection to the MySQL database. Accepts several
148
150
arguments:
@@ -184,6 +186,9 @@ def __init__(self, host="localhost", user=None, password="",
184
186
handshaking with MySQL. (omitted by default)
185
187
:param server_public_key: SHA256 authentication plugin public
186
188
key value.
189
+ :param implicit_tls: Establish TLS immediately, skipping non-TLS
190
+ preamble before upgrading to TLS.
191
+ (default: False)
187
192
:param loop: asyncio loop
188
193
"""
189
194
self ._loop = loop or asyncio .get_event_loop ()
@@ -218,6 +223,7 @@ def __init__(self, host="localhost", user=None, password="",
218
223
self ._auth_plugin_used = ""
219
224
self ._secure = False
220
225
self .server_public_key = server_public_key
226
+ self ._implicit_tls = implicit_tls
221
227
self .salt = None
222
228
223
229
from . import __version__
@@ -241,7 +247,10 @@ def __init__(self, host="localhost", user=None, password="",
241
247
self .use_unicode = use_unicode
242
248
243
249
self ._ssl_context = ssl
244
- if ssl :
250
+ # TLS is required when implicit_tls is True
251
+ if implicit_tls and not self ._ssl_context :
252
+ self ._ssl_context = ssllib .create_default_context ()
253
+ if ssl and not implicit_tls :
245
254
client_flag |= CLIENT .SSL
246
255
247
256
self ._encoding = charset_by_name (self ._charset ).encoding
@@ -536,7 +545,8 @@ async def _connect(self):
536
545
537
546
self ._next_seq_id = 0
538
547
539
- await self ._get_server_information ()
548
+ if not self ._implicit_tls :
549
+ await self ._get_server_information ()
540
550
await self ._request_authentication ()
541
551
542
552
self .connected_time = self ._loop .time ()
@@ -738,7 +748,8 @@ async def _execute_command(self, command, sql):
738
748
739
749
async def _request_authentication (self ):
740
750
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
741
- if int (self .server_version .split ('.' , 1 )[0 ]) >= 5 :
751
+ # FIXME: change this before merge
752
+ if self ._implicit_tls or int (self .server_version .split ('.' , 1 )[0 ]) >= 5 :
742
753
self .client_flag |= CLIENT .MULTI_RESULTS
743
754
744
755
if self .user is None :
@@ -748,8 +759,10 @@ async def _request_authentication(self):
748
759
data_init = struct .pack ('<iIB23s' , self .client_flag , MAX_PACKET_LEN ,
749
760
charset_id , b'' )
750
761
751
- if self ._ssl_context and self .server_capabilities & CLIENT .SSL :
752
- self .write_packet (data_init )
762
+ if self ._ssl_context and \
763
+ (self ._implicit_tls or self .server_capabilities & CLIENT .SSL ):
764
+ if not self ._implicit_tls :
765
+ self .write_packet (data_init )
753
766
754
767
# Stop sending events to data_received
755
768
self ._writer .transport .pause_reading ()
@@ -771,6 +784,9 @@ async def _request_authentication(self):
771
784
server_hostname = self ._host
772
785
)
773
786
787
+ if self ._implicit_tls :
788
+ await self ._get_server_information ()
789
+
774
790
self ._secure = True
775
791
776
792
if isinstance (self .user , str ):
0 commit comments