diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index c153fcda..0b30c582 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -14,7 +14,7 @@ BeginLoadQueryEvent, ExecuteLoadQueryEvent, HeartbeatLogEvent, NotImplementedEvent, MariadbGtidEvent, MariadbAnnotateRowsEvent, RandEvent, MariadbStartEncryptionEvent, RowsQueryLogEvent, - MariadbGtidListEvent, MariadbBinLogCheckPointEvent) + MariadbGtidListEvent, MariadbBinLogCheckPointEvent, UserVarEvent) from .exceptions import BinLogNotEnabled from .gtid import GtidSet from .packet import BinLogPacketWrapper @@ -626,7 +626,8 @@ def _allowed_event_list(self, only_events, ignored_events, RandEvent, MariadbStartEncryptionEvent, MariadbGtidListEvent, - MariadbBinLogCheckPointEvent + MariadbBinLogCheckPointEvent, + UserVarEvent )) if ignored_events is not None: for e in ignored_events: diff --git a/pymysqlreplication/event.py b/pymysqlreplication/event.py index 12db2915..7cc3fdaf 100644 --- a/pymysqlreplication/event.py +++ b/pymysqlreplication/event.py @@ -3,8 +3,10 @@ import binascii import struct import datetime +import decimal from pymysqlreplication.constants.STATUS_VAR_KEY import * from pymysqlreplication.exceptions import StatusVariableMismatch +from typing import Union, Optional class BinLogEvent(object): @@ -51,7 +53,6 @@ def _dump(self): """Core data dumped for the event""" pass - class GtidEvent(BinLogEvent): """GTID change in binlog event """ @@ -519,7 +520,6 @@ class RandEvent(BinLogEvent): :ivar seed1: int - value for the first seed :ivar seed2: int - value for the second seed """ - def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): super().__init__(from_packet, event_size, table_map, ctl_connection, **kwargs) @@ -542,6 +542,149 @@ def _dump(self): print("seed1: %d" % (self.seed1)) print("seed2: %d" % (self.seed2)) +class UserVarEvent(BinLogEvent): + """ + UserVarEvent is generated every time a statement uses a user variable. + Indicates the value to use for the user variable in the next statement. + + :ivar name_len: int - Length of user variable + :ivar name: str - User variable name + :ivar value: str - Value of the user variable + :ivar type: int - Type of the user variable + :ivar charset: int - The number of the character set for the user variable + :ivar is_null: int - Non-zero if the variable value is the SQL NULL value, 0 otherwise + :ivar flags: int - Extra flags associated with the user variable + """ + + def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): + super(UserVarEvent, self).__init__(from_packet, event_size, table_map, ctl_connection, **kwargs) + + # Payload + self.name_len: int = self.packet.read_uint32() + self.name: str = self.packet.read(self.name_len).decode() + self.is_null: int = self.packet.read_uint8() + self.type_to_codes_and_method: dict = { + 0x00: ['STRING_RESULT', self._read_string], + 0x01: ['REAL_RESULT', self._read_real], + 0x02: ['INT_RESULT', self._read_int], + 0x03: ['ROW_RESULT', self._read_default], + 0x04: ['DECIMAL_RESULT', self._read_decimal] + } + + self.value: Optional[Union[str, float, int, decimal.Decimal]] = None + self.flags: Optional[int] = None + self.temp_value_buffer: Union[bytes, memoryview] = b'' + + if not self.is_null: + self.type: int = self.packet.read_uint8() + self.charset: int = self.packet.read_uint32() + self.value_len: int = self.packet.read_uint32() + self.temp_value_buffer: Union[bytes, memoryview] = self.packet.read(self.value_len) + self.flags: int = self.packet.read_uint8() + self._set_value_from_temp_buffer() + else: + self.type, self.charset, self.value_len, self.value, self.flags = None, None, None, None, None + + def _set_value_from_temp_buffer(self): + """ + Set the value from the temporary buffer based on the type code. + """ + if self.temp_value_buffer: + type_code, read_method = self.type_to_codes_and_method.get(self.type, ["UNKNOWN_RESULT", self._read_default]) + if type_code == 'INT_RESULT': + self.value = read_method(self.temp_value_buffer, self.flags) + else: + self.value = read_method(self.temp_value_buffer) + + def _read_string(self, buffer: bytes) -> str: + """ + Read string data. + """ + return buffer.decode() + + def _read_real(self, buffer: bytes) -> float: + """ + Read real data. + """ + return struct.unpack(' int: + """ + Read integer data. + """ + fmt = ' decimal.Decimal: + """ + Read decimal data. + """ + self.precision = self.temp_value_buffer[0] + self.decimals = self.temp_value_buffer[1] + raw_decimal = self.temp_value_buffer[2:] + return self._parse_decimal_from_bytes(raw_decimal, self.precision, self.decimals) + + def _read_default(self) -> bytes: + """ + Read default data. + Used when the type is None. + """ + return self.packet.read(self.value_len) + + @staticmethod + def _parse_decimal_from_bytes(raw_decimal: bytes, precision: int, decimals: int) -> decimal.Decimal: + """ + Parse decimal from bytes. + """ + digits_per_integer = 9 + compressed_bytes = [0, 1, 1, 2, 2, 3, 3, 4, 4, 4] + integral = precision - decimals + + uncomp_integral, comp_integral = divmod(integral, digits_per_integer) + uncomp_fractional, comp_fractional = divmod(decimals, digits_per_integer) + + res = "-" if not raw_decimal[0] & 0x80 else "" + mask = -1 if res == "-" else 0 + raw_decimal = bytearray([raw_decimal[0] ^ 0x80]) + raw_decimal[1:] + + def decode_decimal_decompress_value(comp_indx, data, mask): + size = compressed_bytes[comp_indx] + if size > 0: + databuff = bytearray(data[:size]) + for i in range(size): + databuff[i] = (databuff[i] ^ mask) & 0xFF + return size, int.from_bytes(databuff, byteorder='big') + return 0, 0 + + pointer, value = decode_decimal_decompress_value(comp_integral, raw_decimal, mask) + res += str(value) + + for _ in range(uncomp_integral): + value = struct.unpack('>i', raw_decimal[pointer:pointer+4])[0] ^ mask + res += '%09d' % value + pointer += 4 + + res += "." + + for _ in range(uncomp_fractional): + value = struct.unpack('>i', raw_decimal[pointer:pointer+4])[0] ^ mask + res += '%09d' % value + pointer += 4 + + size, value = decode_decimal_decompress_value(comp_fractional, raw_decimal[pointer:], mask) + if size > 0: + res += '%0*d' % (comp_fractional, value) + return decimal.Decimal(res) + + def _dump(self) -> None: + super(UserVarEvent, self)._dump() + print("User variable name: %s" % self.name) + print("Is NULL: %s" % ("Yes" if self.is_null else "No")) + if not self.is_null: + print("Type: %s" % self.type_to_codes_and_method.get(self.type, ['UNKNOWN_TYPE'])[0]) + print("Charset: %s" % self.charset) + print("Value: %s" % self.value) + print("Flags: %s" % self.flags) class MariadbStartEncryptionEvent(BinLogEvent): """ diff --git a/pymysqlreplication/packet.py b/pymysqlreplication/packet.py index 1d2e408b..7a801f13 100644 --- a/pymysqlreplication/packet.py +++ b/pymysqlreplication/packet.py @@ -72,6 +72,7 @@ class BinLogPacketWrapper(object): constants.XA_PREPARE_EVENT: event.XAPrepareEvent, constants.ROWS_QUERY_LOG_EVENT: event.RowsQueryLogEvent, constants.RAND_EVENT: event.RandEvent, + constants.USER_VAR_EVENT: event.UserVarEvent, # row_event constants.UPDATE_ROWS_EVENT_V1: row_event.UpdateRowsEvent, constants.WRITE_ROWS_EVENT_V1: row_event.WriteRowsEvent, diff --git a/pymysqlreplication/tests/test_basic.py b/pymysqlreplication/tests/test_basic.py index cb27dada..411894cd 100644 --- a/pymysqlreplication/tests/test_basic.py +++ b/pymysqlreplication/tests/test_basic.py @@ -27,9 +27,9 @@ def ignoredEvents(self): return [GtidEvent] def test_allowed_event_list(self): - self.assertEqual(len(self.stream._allowed_event_list(None, None, False)), 22) - self.assertEqual(len(self.stream._allowed_event_list(None, None, True)), 21) - self.assertEqual(len(self.stream._allowed_event_list(None, [RotateEvent], False)), 21) + self.assertEqual(len(self.stream._allowed_event_list(None, None, False)), 23) + self.assertEqual(len(self.stream._allowed_event_list(None, None, True)), 22) + self.assertEqual(len(self.stream._allowed_event_list(None, [RotateEvent], False)), 22) self.assertEqual(len(self.stream._allowed_event_list([RotateEvent], None, False)), 1) def test_read_query_event(self): @@ -1009,6 +1009,221 @@ def test_parsing(self): gtid = Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac:1-:1") gtid = Gtid("57b70f4e-20d3-11e5-a393-4a63946f7eac::1") +class TestStatementConnectionSetting(base.PyMySQLReplicationTestCase): + def setUp(self): + super(TestStatementConnectionSetting, self).setUp() + self.stream.close() + self.stream = BinLogStreamReader( + self.database, + server_id=1024, + only_events=(RandEvent, UserVarEvent, QueryEvent), + fail_on_table_metadata_unavailable=True + ) + self.execute("SET @@binlog_format='STATEMENT'") + + def test_rand_event(self): + self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data INT NOT NULL, PRIMARY KEY (id))") + self.execute("INSERT INTO test (data) VALUES(RAND())") + self.execute("COMMIT") + + self.assertEqual(self.bin_log_format(), "STATEMENT") + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + + expected_rand_event = self.stream.fetchone() + self.assertIsInstance(expected_rand_event, RandEvent) + self.assertEqual(type(expected_rand_event.seed1), int) + self.assertEqual(type(expected_rand_event.seed2), int) + + def test_user_var_string_event(self): + self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data VARCHAR(50), PRIMARY KEY (id))") + self.execute("SET @test_user_var = 'foo'") + self.execute("INSERT INTO test (data) VALUES(@test_user_var)") + self.execute("COMMIT") + + self.assertEqual(self.bin_log_format(), "STATEMENT") + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var") + self.assertEqual(expected_user_var_event.value, "foo") + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 0) + self.assertEqual(expected_user_var_event.charset, 33) + + def test_user_var_real_event(self): + self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data REAL, PRIMARY KEY (id))") + self.execute("SET @test_user_var = @@timestamp") + self.execute("INSERT INTO test (data) VALUES(@test_user_var)") + self.execute("COMMIT") + + self.assertEqual(self.bin_log_format(), "STATEMENT") + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var") + self.assertIsInstance(expected_user_var_event.value,float) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 1) + self.assertEqual(expected_user_var_event.charset, 33) + + def test_user_var_int_event(self): + self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 INT, data2 INT, data3 INT, PRIMARY KEY (id))") + self.execute("SET @test_user_var1 = 5") + self.execute("SET @test_user_var2 = 0") + self.execute("SET @test_user_var3 = -5") + self.execute("INSERT INTO test (data1, data2, data3) VALUES(@test_user_var1, @test_user_var2, @test_user_var3)") + self.execute("COMMIT") + + self.assertEqual(self.bin_log_format(), "STATEMENT") + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var1") + self.assertEqual(expected_user_var_event.value, 5) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 2) + self.assertEqual(expected_user_var_event.charset, 33) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var2") + self.assertEqual(expected_user_var_event.value, 0) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 2) + self.assertEqual(expected_user_var_event.charset, 33) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var3") + self.assertEqual(expected_user_var_event.value, -5) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 2) + self.assertEqual(expected_user_var_event.charset, 33) + + def test_user_var_int24_event(self): + self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 MEDIUMINT, data2 MEDIUMINT, data3 MEDIUMINT UNSIGNED, PRIMARY KEY (id))") + self.execute("SET @test_user_var1 = 8388607") + self.execute("SET @test_user_var2 = -8388607") + self.execute("SET @test_user_var3 = 16777215") + self.execute("INSERT INTO test (data1, data2, data3) VALUES(@test_user_var1, @test_user_var2, @test_user_var3)") + self.execute("COMMIT") + + self.assertEqual(self.bin_log_format(), "STATEMENT") + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var1") + self.assertEqual(expected_user_var_event.value, 8388607) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 2) + self.assertEqual(expected_user_var_event.charset, 33) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var2") + self.assertEqual(expected_user_var_event.value, -8388607) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 2) + self.assertEqual(expected_user_var_event.charset, 33) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var3") + self.assertEqual(expected_user_var_event.value, 16777215) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 2) + self.assertEqual(expected_user_var_event.charset, 33) + + def test_user_var_longlong_event(self): + self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 BIGINT, data2 BIGINT, data3 BIGINT UNSIGNED, PRIMARY KEY (id))") + self.execute("SET @test_user_var1 = 9223372036854775807") + self.execute("SET @test_user_var2 = -9223372036854775808") + self.execute("SET @test_user_var3 = 18446744073709551615") + self.execute("INSERT INTO test (data1, data2, data3) VALUES(@test_user_var1, @test_user_var2, @test_user_var3)") + self.execute("COMMIT") + + self.assertEqual(self.bin_log_format(), "STATEMENT") + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var1") + self.assertEqual(expected_user_var_event.value, 9223372036854775807) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 2) + self.assertEqual(expected_user_var_event.charset, 33) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var2") + self.assertEqual(expected_user_var_event.value, -9223372036854775808) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 2) + self.assertEqual(expected_user_var_event.charset, 33) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var3") + self.assertEqual(expected_user_var_event.value, 18446744073709551615) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 2) + self.assertEqual(expected_user_var_event.charset, 33) + + def test_user_var_decimal_event(self): + self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data1 DECIMAL, data2 DECIMAL, PRIMARY KEY (id))") + self.execute("SET @test_user_var1 = 5.25") + self.execute("SET @test_user_var2 = -5.25") + self.execute("INSERT INTO test (data1, data2) VALUES(@test_user_var1, @test_user_var2)") + self.execute("COMMIT") + + self.assertEqual(self.bin_log_format(), "STATEMENT") + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var1") + self.assertEqual(expected_user_var_event.value, 5.25) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 4) + self.assertEqual(expected_user_var_event.charset, 33) + + expected_user_var_event = self.stream.fetchone() + self.assertIsInstance(expected_user_var_event, UserVarEvent) + self.assertIsInstance(expected_user_var_event.name_len, int) + self.assertEqual(expected_user_var_event.name, "test_user_var2") + self.assertEqual(expected_user_var_event.value, -5.25) + self.assertEqual(expected_user_var_event.is_null, 0) + self.assertEqual(expected_user_var_event.type, 4) + self.assertEqual(expected_user_var_event.charset, 33) + + def tearDown(self): + self.execute("SET @@binlog_format='ROW'") + self.assertEqual(self.bin_log_format(), "ROW") + super(TestStatementConnectionSetting, self).tearDown() + class TestMariadbBinlogStreamReader(base.PyMySQLReplicationMariaDbTestCase): def test_binlog_checkpoint_event(self): self.stream.close() @@ -1065,7 +1280,6 @@ def test_annotate_rows_event(self): #Check self.sql_statement self.assertEqual(event.sql_statement,insert_query) self.assertIsInstance(event,MariadbAnnotateRowsEvent) - def test_start_encryption_event(self): query = "CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data VARCHAR (50) NOT NULL, PRIMARY KEY (id))" self.execute(query) @@ -1133,39 +1347,6 @@ def test_gtid_list_event(self): event = self.stream.fetchone() self.assertEqual(event.event_type,163) self.assertEqual(event.gtid_list[0].gtid, '0-1-15') - - - -class TestStatementConnectionSetting(base.PyMySQLReplicationTestCase): - def setUp(self): - super().setUp() - self.stream.close() - self.stream = BinLogStreamReader( - self.database, - server_id=1024, - only_events=(RandEvent, QueryEvent), - fail_on_table_metadata_unavailable=True - ) - self.execute("SET @@binlog_format='STATEMENT'") - - def test_rand_event(self): - self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data INT NOT NULL, PRIMARY KEY (id))") - self.execute("INSERT INTO test (data) VALUES(RAND())") - self.execute("COMMIT") - - self.assertEqual(self.bin_log_format(), "STATEMENT") - self.assertIsInstance(self.stream.fetchone(), QueryEvent) - self.assertIsInstance(self.stream.fetchone(), QueryEvent) - - expect_rand_event = self.stream.fetchone() - self.assertIsInstance(expect_rand_event, RandEvent) - self.assertEqual(type(expect_rand_event.seed1), int) - self.assertEqual(type(expect_rand_event.seed2), int) - - def tearDown(self): - self.execute("SET @@binlog_format='ROW'") - self.assertEqual(self.bin_log_format(), "ROW") - super().tearDown() class TestRowsQueryLogEvents(base.PyMySQLReplicationTestCase):