Skip to content

Commit 173eb36

Browse files
committed
Add implicit_tls connect arg to support non-standard implicit TLS connections, such as Google Cloud SQL
fixes aio-libs#757
1 parent ab13f94 commit 173eb36

14 files changed

+265
-35
lines changed

.codecov.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
codecov:
22
notify:
3-
after_n_builds: 40
3+
after_n_builds: 6

.github/workflows/ci-cd.yml

+48-15
Original file line numberDiff line numberDiff line change
@@ -406,18 +406,18 @@ jobs:
406406
- ubuntu-latest
407407
py:
408408
- '3.7'
409-
- '3.8'
410-
- '3.9'
411-
- '3.10'
409+
# - '3.8'
410+
# - '3.9'
411+
# - '3.10'
412412
- '3.11-dev'
413413
db:
414414
- [mysql, '5.7']
415415
- [mysql, '8.0']
416-
- [mariadb, '10.3']
417-
- [mariadb, '10.4']
418-
- [mariadb, '10.5']
419-
- [mariadb, '10.6']
420-
- [mariadb, '10.7']
416+
# - [mariadb, '10.3']
417+
# - [mariadb, '10.4']
418+
# - [mariadb, '10.5']
419+
# - [mariadb, '10.6']
420+
# - [mariadb, '10.7']
421421
- [mariadb, '10.8']
422422

423423
fail-fast: false
@@ -449,6 +449,13 @@ jobs:
449449
options: '--name=mysqld'
450450
env:
451451
MYSQL_ROOT_PASSWORD: rootpw
452+
haproxy:
453+
image: haproxytech/haproxy-alpine:2.6
454+
ports:
455+
- 13306:13306
456+
volumes:
457+
- "/tmp/run-${{ join(matrix.db, '-') }}/:/var/lib/haproxy/socket-mount/"
458+
options: '--name=haproxy'
452459

453460
steps:
454461
- name: Setup Python ${{ matrix.py }}
@@ -569,6 +576,18 @@ jobs:
569576
# unfortunately we need this hacky workaround as GitHub Actions service containers can't reference data from our repo.
570577
- name: Prepare mysql
571578
run: |
579+
# we need to ensure that the socket path is readable from haproxy and
580+
# writable for the user running the DB process
581+
sudo chmod 0777 /tmp/run-${{ join(matrix.db, '-') }}
582+
583+
# inject HAproxy configuration
584+
docker container stop haproxy
585+
586+
docker container cp "${{ github.workspace }}/tests/ssl_resources/haproxy.cfg" haproxy:/usr/local/etc/haproxy/haproxy.cfg
587+
docker container cp "${{ github.workspace }}/tests/ssl_resources/ssl/server-combined.pem" haproxy:/usr/local/etc/haproxy/haproxy.pem
588+
589+
docker container start haproxy
590+
572591
# ensure server is started up
573592
while :
574593
do
@@ -582,9 +601,6 @@ jobs:
582601
docker container cp "${{ github.workspace }}/tests/ssl_resources/tls.cnf" mysqld:/etc/mysql/conf.d/aiomysql-tls.cnf
583602
584603
# use custom socket path
585-
# we need to ensure that the socket path is writable for the user running the DB process in the container
586-
sudo chmod 0777 /tmp/run-${{ join(matrix.db, '-') }}
587-
588604
docker container cp "${{ github.workspace }}/tests/ssl_resources/socket.cnf" mysqld:/etc/mysql/conf.d/aiomysql-socket.cnf
589605
590606
docker container start mysqld
@@ -598,11 +614,28 @@ jobs:
598614
599615
mysql -h127.0.0.1 -uroot "-p$MYSQL_ROOT_PASSWORD" -e "SET GLOBAL local_infile=on"
600616
617+
# This should get removed before merging
618+
mysql -h127.0.0.1 -uroot "-p$MYSQL_ROOT_PASSWORD" -e "select user()"
619+
mysql -h127.0.0.1 -uroot "-p$MYSQL_ROOT_PASSWORD" -e "select host, user, hex(authentication_string) from mysql.user"
620+
601621
- name: Run tests
602-
run: |
603-
# timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs
604-
timeout --preserve-status --signal=INT --verbose 570s \
605-
pytest --capture=no --verbosity 2 --cov-report term --cov-report xml --cov aiomysql --cov tests ./tests --mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock" --mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306"
622+
# timeout ensures a more or less clean stop by sending a KeyboardInterrupt which will still provide useful logs
623+
run: >-
624+
timeout
625+
--preserve-status
626+
--signal=INT
627+
--verbose 570s
628+
pytest
629+
--capture=no
630+
--verbosity 2
631+
--cov-report term
632+
--cov-report xml
633+
--cov aiomysql
634+
--cov tests
635+
./tests
636+
--mysql-unix-socket "unix-${{ join(matrix.db, '') }}=/tmp/run-${{ join(matrix.db, '-') }}/mysql.sock"
637+
--mysql-address "tcp-${{ join(matrix.db, '') }}=127.0.0.1:3306"
638+
--mysql-address-tls "tls-${{ join(matrix.db, '') }}=127.0.0.1:13306"
606639
env:
607640
PYTHONUNBUFFERED: 1
608641
timeout-minutes: 10

CHANGES.txt

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ next (unreleased)
1212
| aiomysql now reraises the original exception during connect() if it's not `IOError`, `OSError` or `asyncio.TimeoutError`.
1313
| This was previously always raised as `OperationalError`.
1414

15+
* Add `implicit_tls` connect arg to support non-standard implicit TLS connections, such as Google Cloud SQL #757
16+
1517
0.1.1 (2022-05-08)
1618
^^^^^^^^^^^^^^^^^^
1719

aiomysql/connection.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import warnings
1010
import configparser
1111
import getpass
12+
import ssl as ssllib
1213
from functools import partial
1314

1415
from pymysql.charset import charset_by_name, charset_by_id
@@ -53,7 +54,7 @@ def connect(host="localhost", user=None, password="",
5354
connect_timeout=None, read_default_group=None,
5455
autocommit=False, echo=False,
5556
local_infile=False, loop=None, ssl=None, auth_plugin='',
56-
program_name='', server_public_key=None):
57+
program_name='', server_public_key=None, implicit_tls=False):
5758
"""See connections.Connection.__init__() for information about
5859
defaults."""
5960
coro = _connect(host=host, user=user, password=password, db=db,
@@ -66,7 +67,8 @@ def connect(host="localhost", user=None, password="",
6667
read_default_group=read_default_group,
6768
autocommit=autocommit, echo=echo,
6869
local_infile=local_infile, loop=loop, ssl=ssl,
69-
auth_plugin=auth_plugin, program_name=program_name)
70+
auth_plugin=auth_plugin, program_name=program_name,
71+
implicit_tls=implicit_tls)
7072
return _ConnectionContextManager(coro)
7173

7274

@@ -142,7 +144,7 @@ def __init__(self, host="localhost", user=None, password="",
142144
connect_timeout=None, read_default_group=None,
143145
autocommit=False, echo=False,
144146
local_infile=False, loop=None, ssl=None, auth_plugin='',
145-
program_name='', server_public_key=None):
147+
program_name='', server_public_key=None, implicit_tls=False):
146148
"""
147149
Establish a connection to the MySQL database. Accepts several
148150
arguments:
@@ -184,6 +186,9 @@ def __init__(self, host="localhost", user=None, password="",
184186
handshaking with MySQL. (omitted by default)
185187
:param server_public_key: SHA256 authentication plugin public
186188
key value.
189+
:param implicit_tls: Establish TLS immediately, skipping non-TLS
190+
preamble before upgrading to TLS.
191+
(default: False)
187192
:param loop: asyncio loop
188193
"""
189194
self._loop = loop or asyncio.get_event_loop()
@@ -218,6 +223,7 @@ def __init__(self, host="localhost", user=None, password="",
218223
self._auth_plugin_used = ""
219224
self._secure = False
220225
self.server_public_key = server_public_key
226+
self._implicit_tls = implicit_tls
221227
self.salt = None
222228

223229
from . import __version__
@@ -241,7 +247,10 @@ def __init__(self, host="localhost", user=None, password="",
241247
self.use_unicode = use_unicode
242248

243249
self._ssl_context = ssl
244-
if ssl:
250+
# TLS is required when implicit_tls is True
251+
if implicit_tls and not self._ssl_context:
252+
self._ssl_context = ssllib.create_default_context()
253+
if ssl and not implicit_tls:
245254
client_flag |= CLIENT.SSL
246255

247256
self._encoding = charset_by_name(self._charset).encoding
@@ -536,7 +545,8 @@ async def _connect(self):
536545

537546
self._next_seq_id = 0
538547

539-
await self._get_server_information()
548+
if not self._implicit_tls:
549+
await self._get_server_information()
540550
await self._request_authentication()
541551

542552
self.connected_time = self._loop.time()
@@ -738,7 +748,8 @@ async def _execute_command(self, command, sql):
738748

739749
async def _request_authentication(self):
740750
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
741-
if int(self.server_version.split('.', 1)[0]) >= 5:
751+
# FIXME: change this before merge
752+
if self._implicit_tls or int(self.server_version.split('.', 1)[0]) >= 5:
742753
self.client_flag |= CLIENT.MULTI_RESULTS
743754

744755
if self.user is None:
@@ -748,8 +759,10 @@ async def _request_authentication(self):
748759
data_init = struct.pack('<iIB23s', self.client_flag, MAX_PACKET_LEN,
749760
charset_id, b'')
750761

751-
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
752-
self.write_packet(data_init)
762+
if self._ssl_context and \
763+
(self._implicit_tls or self.server_capabilities & CLIENT.SSL):
764+
if not self._implicit_tls:
765+
self.write_packet(data_init)
753766

754767
# Stop sending events to data_received
755768
self._writer.transport.pause_reading()
@@ -771,6 +784,9 @@ async def _request_authentication(self):
771784
server_hostname=self._host
772785
)
773786

787+
if self._implicit_tls:
788+
await self._get_server_information()
789+
774790
self._secure = True
775791

776792
if isinstance(self.user, str):

docs/connection.rst

+6-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Example::
4747
connect_timeout=None, read_default_group=None,
4848
autocommit=False, echo=False
4949
ssl=None, auth_plugin='', program_name='',
50-
server_public_key=None, loop=None)
50+
server_public_key=None, loop=None, implicit_tls=False)
5151

5252
A :ref:`coroutine <coroutine>` that connects to MySQL.
5353

@@ -93,6 +93,11 @@ Example::
9393
``sys.argv[0]`` is no longer passed by default
9494
:param server_public_key: SHA256 authenticaiton plugin public key value.
9595
:param loop: asyncio event loop instance or ``None`` for default one.
96+
:param implicit_tls: Establish TLS immediately, skipping non-TLS
97+
preamble before upgrading to TLS.
98+
(default: False)
99+
100+
.. versionadded:: 0.2
96101
:returns: :class:`Connection` instance.
97102

98103

tests/conftest.py

+54-9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import gc
33
import os
44
import re
5+
import socket
56
import ssl
67
import sys
78

@@ -63,13 +64,26 @@ def pytest_generate_tests(metafunc):
6364

6465
if ":" in addr:
6566
addr = addr.split(":", 1)
66-
mysql_addresses.append((addr[0], int(addr[1])))
67+
mysql_addresses.append((addr[0], int(addr[1]), False))
6768
else:
68-
mysql_addresses.append((addr, 3306))
69+
mysql_addresses.append((addr, 3306, False))
70+
71+
opt_mysql_address_tls =\
72+
list(metafunc.config.getoption("mysql_address_tls"))
73+
for i in range(len(opt_mysql_address_tls)):
74+
if "=" in opt_mysql_address_tls[i]:
75+
label, addr = opt_mysql_address_tls[i].split("=", 1)
76+
ids.append(label)
77+
else:
78+
addr = opt_mysql_address_tls[i]
79+
ids.append("tls{}".format(i))
80+
81+
addr = addr.split(":", 1)
82+
mysql_addresses.append((addr[0], int(addr[1]), True))
6983

7084
# default to connecting to localhost
7185
if len(mysql_addresses) == 0:
72-
mysql_addresses = [("127.0.0.1", 3306)]
86+
mysql_addresses = [("127.0.0.1", 3306, False)]
7387
ids = ["tcp-local"]
7488

7589
assert len(mysql_addresses) == len(set(mysql_addresses)), \
@@ -153,6 +167,12 @@ def pytest_addoption(parser):
153167
default=[],
154168
help="list of addresses to connect to: [name=]host[:port]",
155169
)
170+
parser.addoption(
171+
"--mysql-address-tls",
172+
action="append",
173+
default=[],
174+
help="list of addresses to connect to using implicit TLS: [name=]host:port",
175+
)
156176
parser.addoption(
157177
"--mysql-unix-socket",
158178
action="append",
@@ -249,6 +269,7 @@ def _register_table(table_name):
249269
@pytest.fixture(scope='session')
250270
def mysql_server(mysql_address):
251271
unix_socket = type(mysql_address) is str
272+
implicit_tls = not unix_socket and mysql_address[2]
252273

253274
if not unix_socket:
254275
ssl_directory = os.path.join(os.path.dirname(__file__),
@@ -270,14 +291,34 @@ def mysql_server(mysql_address):
270291
else:
271292
server_params["host"] = mysql_address[0]
272293
server_params["port"] = mysql_address[1]
294+
295+
if not unix_socket and not implicit_tls:
273296
server_params["ssl"] = ctx
274297

275298
try:
276-
connection = pymysql.connect(
277-
db='mysql',
278-
charset='utf8mb4',
279-
cursorclass=pymysql.cursors.DictCursor,
280-
**server_params)
299+
if implicit_tls:
300+
sock = ctx.wrap_socket(
301+
socket.create_connection(
302+
(server_params["host"], server_params["port"]),
303+
),
304+
server_hostname=server_params["host"],
305+
)
306+
connection = pymysql.Connection(
307+
db='mysql',
308+
charset='utf8mb4',
309+
cursorclass=pymysql.cursors.DictCursor,
310+
**server_params,
311+
defer_connect=True,
312+
)
313+
connection.connect(sock)
314+
315+
else:
316+
connection = pymysql.connect(
317+
db='mysql',
318+
charset='utf8mb4',
319+
cursorclass=pymysql.cursors.DictCursor,
320+
**server_params,
321+
)
281322

282323
with connection.cursor() as cursor:
283324
cursor.execute("SELECT VERSION() AS version")
@@ -297,7 +338,7 @@ def mysql_server(mysql_address):
297338
pytest.fail("Unable to determine database type from {!r}"
298339
.format(server_version_tuple))
299340

300-
if not unix_socket:
341+
if not unix_socket and not implicit_tls:
301342
cursor.execute("SHOW VARIABLES LIKE '%ssl%';")
302343

303344
result = cursor.fetchall()
@@ -353,6 +394,10 @@ def mysql_server(mysql_address):
353394
except Exception:
354395
pytest.fail("Cannot initialize MySQL environment")
355396

397+
if implicit_tls:
398+
server_params["ssl"] = ctx
399+
server_params["implicit_tls"] = implicit_tls
400+
356401
return {
357402
"conn_params": server_params,
358403
"server_version": server_version,

tests/sa/test_sa_compiled_cache.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ async def _make_engine(**kwargs):
2626
}
2727
if "ssl" in mysql_params:
2828
conn_args["ssl"] = mysql_params["ssl"]
29+
if "implicit_tls" in mysql_params:
30+
conn_args["implicit_tls"] = mysql_params["implicit_tls"]
2931

3032
engine = await sa.create_engine(
3133
db=mysql_params['db'],

tests/sa/test_sa_default.py

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ async def _make_engine(**kwargs):
3333
}
3434
if "ssl" in mysql_params:
3535
conn_args["ssl"] = mysql_params["ssl"]
36+
if "implicit_tls" in mysql_params:
37+
conn_args["implicit_tls"] = mysql_params["implicit_tls"]
3638

3739
engine = await sa.create_engine(
3840
db=mysql_params['db'],

tests/sa/test_sa_engine.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ async def _make_engine(**kwargs):
2626
}
2727
if "ssl" in mysql_params:
2828
conn_args["ssl"] = mysql_params["ssl"]
29+
if "implicit_tls" in mysql_params:
30+
conn_args["implicit_tls"] = mysql_params["implicit_tls"]
2931

3032
engine = await sa.create_engine(
3133
db=mysql_params['db'],

0 commit comments

Comments
 (0)