Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typing - 'test_basic.py', 'base.py', 'column.py', 'test_abnormal.py', 'benchmark.py' #82

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions pymysqlreplication/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,36 @@
import struct

from .constants import FIELD_TYPE

from typing import Any, Dict, Optional

class Column(object):
"""Definition of a column
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
if len(args) == 3:
self.__parse_column_definition(*args)
else:
self.__dict__.update(kwargs)

def __parse_column_definition(self, column_type, column_schema, packet):
self.type = column_type
self.name = column_schema["COLUMN_NAME"]
self.collation_name = column_schema["COLLATION_NAME"]
self.character_set_name = column_schema["CHARACTER_SET_NAME"]
self.comment = column_schema["COLUMN_COMMENT"]
self.unsigned = column_schema["COLUMN_TYPE"].find("unsigned") != -1
self.zerofill = column_schema["COLUMN_TYPE"].find("zerofill") != -1
self.type_is_bool = False
self.is_primary = column_schema["COLUMN_KEY"] == "PRI"
def __parse_column_definition(self,
column_type: str,
column_schema: Dict[str, Any],
packet: Any
) -> None:
self.type: str = column_type
self.name: str = column_schema["COLUMN_NAME"]
self.collation_name: str = column_schema["COLLATION_NAME"]
self.character_set_name: str = column_schema["CHARACTER_SET_NAME"]
self.comment: str = column_schema["COLUMN_COMMENT"]
self.unsigned: bool = column_schema["COLUMN_TYPE"].find("unsigned") != -1
self.zerofill: bool = column_schema["COLUMN_TYPE"].find("zerofill") != -1
self.type_is_bool: bool = False
self.is_primary: bool = column_schema["COLUMN_KEY"] == "PRI"

# Check for fixed-length binary type. When that's the case then we need
# to zero-pad the values to full length at read time.
self.fixed_binary_length = None
self.fixed_binary_length: Optional[str] = None
if column_schema["DATA_TYPE"] == "binary":
self.fixed_binary_length = column_schema["CHARACTER_OCTET_LENGTH"]

Expand Down Expand Up @@ -65,7 +69,7 @@ def __parse_column_definition(self, column_type, column_schema, packet):
self.bits = (bytes * 8) + bits
self.bytes = int((self.bits + 7) / 8)

def __read_string_metadata(self, packet, column_schema):
def __read_string_metadata(self, packet: Any, column_schema: Dict[str, Any]) -> None:
metadata = (packet.read_uint8() << 8) + packet.read_uint8()
real_type = metadata >> 8
if real_type == FIELD_TYPE.SET or real_type == FIELD_TYPE.ENUM:
Expand All @@ -76,7 +80,7 @@ def __read_string_metadata(self, packet, column_schema):
self.max_length = (((metadata >> 4) & 0x300) ^ 0x300) \
+ (metadata & 0x00ff)

def __read_enum_metadata(self, column_schema):
def __read_enum_metadata(self, column_schema: Dict[str, Any]) -> None:
enums = column_schema["COLUMN_TYPE"]
if self.type == FIELD_TYPE.ENUM:
self.enum_values = [''] + enums.replace('enum(', '')\
Expand All @@ -85,15 +89,15 @@ def __read_enum_metadata(self, column_schema):
self.set_values = enums.replace('set(', '')\
.replace(')', '').replace('\'', '').split(',')

def __eq__(self, other):
def __eq__(self, other: 'Column') -> bool:
return self.data == other.data

def __ne__(self, other):
def __ne__(self, other: 'Column') -> bool:
return not self.__eq__(other)

def serializable_data(self):
def serializable_data(self) -> Dict[str, Any]:
return self.data

@property
def data(self):
def data(self) -> Dict[str, Any]:
return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith('_'))
91 changes: 49 additions & 42 deletions pymysqlreplication/tests/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# -*- coding: utf-8 -*-

import pymysql
import copy
from pymysqlreplication import BinLogStreamReader
import os
import sys
from pymysql.cursors import Cursor
import pymysql
from pymysql import Connection
from typing import Optional

from pymysqlreplication import BinLogStreamReader

if sys.version_info < (2, 7):
import unittest2 as unittest
Expand All @@ -15,96 +19,99 @@


class PyMySQLReplicationTestCase(base):
def ignoredEvents(self):
def ignoredEvents(self) -> list:
return []

def setUp(self):
def setUp(self) -> None:
# default
self.database = {
self.database: dict = {
"host": os.environ.get("MYSQL_5_7") or "localhost",
"user": "root",
"passwd": "",
"port": 3306,
"use_unicode": True,
"charset": "utf8",
"db": "pymysqlreplication_test"
"db": "pymysqlreplication_test",
}

self.conn_control = None
self.conn_control: Optional[Connection] = None
db = copy.copy(self.database)
db["db"] = None
self.connect_conn_control(db)
self.execute("DROP DATABASE IF EXISTS pymysqlreplication_test")
self.execute("CREATE DATABASE pymysqlreplication_test")
db = copy.copy(self.database)
self.connect_conn_control(db)
self.stream = None
self.stream: Optional[BinLogStreamReader] = None
self.resetBinLog()
self.isMySQL56AndMore()
self.__is_mariaDB = None
self.__is_mariaDB: Optional[bool] = None

def getMySQLVersion(self):
def getMySQLVersion(self) -> str:
"""Return the MySQL version of the server
If version is 5.6.10-log the result is 5.6.10
"""
return self.execute("SELECT VERSION()").fetchone()[0].split('-')[0]
return self.execute("SELECT VERSION()").fetchone()[0].split("-")[0]

def isMySQL56AndMore(self):
version = float(self.getMySQLVersion().rsplit('.', 1)[0])
def isMySQL56AndMore(self) -> bool:
version = float(self.getMySQLVersion().rsplit(".", 1)[0])
if version >= 5.6:
return True
return False

def isMySQL57(self):
version = float(self.getMySQLVersion().rsplit('.', 1)[0])
def isMySQL57(self) -> bool:
version = float(self.getMySQLVersion().rsplit(".", 1)[0])
return version == 5.7

def isMySQL80AndMore(self):
version = float(self.getMySQLVersion().rsplit('.', 1)[0])
def isMySQL80AndMore(self) -> bool:
version = float(self.getMySQLVersion().rsplit(".", 1)[0])
return version >= 8.0

def isMariaDB(self):
def isMariaDB(self) -> bool:
if self.__is_mariaDB is None:
self.__is_mariaDB = "MariaDB" in self.execute("SELECT VERSION()").fetchone()[0]
self.__is_mariaDB: bool = (
"MariaDB" in self.execute("SELECT VERSION()").fetchone()[0]
)
return self.__is_mariaDB

@property
def supportsGTID(self):
def supportsGTID(self) -> bool:
if not self.isMySQL56AndMore():
return False
return self.execute("SELECT @@global.gtid_mode ").fetchone()[0] == "ON"

def connect_conn_control(self, db):
def connect_conn_control(self, db) -> None:
if self.conn_control is not None:
self.conn_control.close()
self.conn_control = pymysql.connect(**db)
self.conn_control: Connection = pymysql.connect(**db)

def tearDown(self):
def tearDown(self) -> None:
self.conn_control.close()
self.conn_control = None
self.conn_control: Optional[Connection] = None
self.stream.close()
self.stream = None
self.stream: Optional[BinLogStreamReader] = None

def execute(self, query):
def execute(self, query: str) -> Cursor:
c = self.conn_control.cursor()
c.execute(query)
return c
def execute_with_args(self, query, args):

def execute_with_args(self, query: str, args) -> Cursor:
c = self.conn_control.cursor()
c.execute(query, args)
return c

def resetBinLog(self):
def resetBinLog(self) -> None:
self.execute("RESET MASTER")
if self.stream is not None:
self.stream.close()
self.stream = BinLogStreamReader(self.database, server_id=1024,
ignored_events=self.ignoredEvents())
self.stream: BinLogStreamReader = BinLogStreamReader(
self.database, server_id=1024, ignored_events=self.ignoredEvents()
)

def set_sql_mode(self):
def set_sql_mode(self) -> None:
"""set sql_mode to test with same sql_mode (mysql 5.7 sql_mode default is changed)"""
version = float(self.getMySQLVersion().rsplit('.', 1)[0])
version = float(self.getMySQLVersion().rsplit(".", 1)[0])
if version == 5.7:
self.execute("SET @@sql_mode='NO_ENGINE_SUBSTITUTION'")

Expand All @@ -114,15 +121,15 @@ def bin_log_format(self):
result = cursor.fetchone()
return result[0]

def bin_log_basename(self):
cursor = self.execute('SELECT @@log_bin_basename')
def bin_log_basename(self) -> str:
cursor: Cursor = self.execute("SELECT @@log_bin_basename")
bin_log_basename = cursor.fetchone()[0]
bin_log_basename = bin_log_basename.split("/")[-1]
return bin_log_basename


class PyMySQLReplicationMariaDbTestCase(PyMySQLReplicationTestCase):
def setUp(self):
def setUp(self) -> None:
# default
self.database = {
"host": os.environ.get("MARIADB_10_6") or "localhost",
Expand All @@ -131,22 +138,22 @@ def setUp(self):
"port": int(os.environ.get("MARIADB_10_6_PORT") or 3308),
"use_unicode": True,
"charset": "utf8",
"db": "pymysqlreplication_test"
"db": "pymysqlreplication_test",
}

self.conn_control = None
self.conn_control: Optional[Connection] = None
db = copy.copy(self.database)
db["db"] = None
self.connect_conn_control(db)
self.execute("DROP DATABASE IF EXISTS pymysqlreplication_test")
self.execute("CREATE DATABASE pymysqlreplication_test")
db = copy.copy(self.database)
self.connect_conn_control(db)
self.stream = None
self.stream: Optional[BinLogStreamReader] = None
self.resetBinLog()
def bin_log_basename(self):
cursor = self.execute('SELECT @@log_bin_basename')

def bin_log_basename(self) -> str:
cursor: Cursor = self.execute("SELECT @@log_bin_basename")
bin_log_basename = cursor.fetchone()[0]
bin_log_basename = bin_log_basename.split("/")[-1]
return bin_log_basename
12 changes: 8 additions & 4 deletions pymysqlreplication/tests/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@
import os
from pymysqlreplication import BinLogStreamReader
from pymysqlreplication.row_event import *
from pymysql.connections import Connection
from pymysql.cursors import Cursor

from typing import Any
import cProfile


def execute(con, query):
def execute(con: Connection, query: str) -> Cursor:
c = con.cursor()
c.execute(query)
return c

def consume_events():
def consume_events() -> None:
stream = BinLogStreamReader(connection_settings=database,
server_id=3,
resume_stream=False,
Expand All @@ -44,11 +48,11 @@ def consume_events():
"db": "pymysqlreplication_test"
}

conn = pymysql.connect(**database)
conn: Connection = pymysql.connect(**database)

execute(conn, "DROP DATABASE IF EXISTS pymysqlreplication_test")
execute(conn, "CREATE DATABASE pymysqlreplication_test")
conn = pymysql.connect(**database)
conn: Connection = pymysql.connect(**database)
execute(conn, "CREATE TABLE test (i INT) ENGINE = MEMORY")
execute(conn, "INSERT INTO test VALUES(1)")
execute(conn, "CREATE TABLE test2 (i INT) ENGINE = MEMORY")
Expand Down
6 changes: 3 additions & 3 deletions pymysqlreplication/tests/test_abnormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def ignored_events():
'''Events the BinLogStreamReader should ignore'''
return [GtidEvent]

def test_no_trailing_rotate_event(self):
def test_no_trailing_rotate_event(self) -> None:
'''A missing RotateEvent and skip_to_timestamp cause corruption

This test shows that a binlog file which lacks the trailing RotateEvent
Expand All @@ -42,7 +42,7 @@ def test_no_trailing_rotate_event(self):

binlog = self.execute("SHOW BINARY LOGS").fetchone()[0]

self.stream = BinLogStreamReader(
self.stream: BinLogStreamReader = BinLogStreamReader(
self.database,
server_id=1024,
log_pos=4,
Expand All @@ -54,7 +54,7 @@ def test_no_trailing_rotate_event(self):
# The table_map should be empty because of the binlog being rotated.
self.assertEqual({}, self.stream.table_map)

def _remove_trailing_rotate_event_from_first_binlog(self):
def _remove_trailing_rotate_event_from_first_binlog(self) -> None:
'''Remove the trailing RotateEvent from the first binlog

According to the MySQL Internals Manual, a RotateEvent will be added to
Expand Down
Loading