diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index 0912850b..89d9fb02 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -15,7 +15,7 @@ XidEvent, GtidEvent, StopEvent, XAPrepareEvent, BeginLoadQueryEvent, ExecuteLoadQueryEvent, HeartbeatLogEvent, NotImplementedEvent, MariadbGtidEvent, - MariadbAnnotateRowsEvent) + MariadbAnnotateRowsEvent, RandEvent) from .exceptions import BinLogNotEnabled from .row_event import ( UpdateRowsEvent, WriteRowsEvent, DeleteRowsEvent, TableMapEvent) @@ -621,7 +621,8 @@ def _allowed_event_list(self, only_events, ignored_events, HeartbeatLogEvent, NotImplementedEvent, MariadbGtidEvent, - MariadbAnnotateRowsEvent + MariadbAnnotateRowsEvent, + RandEvent )) if ignored_events is not None: for e in ignored_events: diff --git a/pymysqlreplication/event.py b/pymysqlreplication/event.py index 2a51df22..eb0d9221 100644 --- a/pymysqlreplication/event.py +++ b/pymysqlreplication/event.py @@ -454,6 +454,39 @@ def _dump(self): print("type: %d" % (self.type)) print("Value: %d" % (self.value)) +class RandEvent(BinLogEvent): + """ + RandEvent is generated every time a statement uses the RAND() function. + Indicates the seed values to use for generating a random number with RAND() in the next statement. + + RandEvent only works in statement-based logging (need to set binlog_format as 'STATEMENT') + and only works when the seed number is not specified. + + :ivar seed1: int - value for the first seed + :ivar seed2: int - value for the second seed + """ + + def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): + super(RandEvent, self).__init__(from_packet, event_size, table_map, + ctl_connection, **kwargs) + # Payload + self._seed1 = self.packet.read_uint64() + self._seed2 = self.packet.read_uint64() + + @property + def seed1(self): + """Get the first seed value""" + return self._seed1 + + @property + def seed2(self): + """Get the second seed value""" + return self._seed2 + + def _dump(self): + super(RandEvent, self)._dump() + print("seed1: %d" % (self.seed1)) + print("seed2: %d" % (self.seed2)) class NotImplementedEvent(BinLogEvent): def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs): diff --git a/pymysqlreplication/packet.py b/pymysqlreplication/packet.py index fb5a42b5..8c918c16 100644 --- a/pymysqlreplication/packet.py +++ b/pymysqlreplication/packet.py @@ -70,6 +70,7 @@ class BinLogPacketWrapper(object): constants.EXECUTE_LOAD_QUERY_EVENT: event.ExecuteLoadQueryEvent, constants.HEARTBEAT_LOG_EVENT: event.HeartbeatLogEvent, constants.XA_PREPARE_EVENT: event.XAPrepareEvent, + constants.RAND_EVENT: event.RandEvent, # row_event constants.UPDATE_ROWS_EVENT_V1: row_event.UpdateRowsEvent, constants.WRITE_ROWS_EVENT_V1: row_event.WriteRowsEvent, diff --git a/pymysqlreplication/tests/test_basic.py b/pymysqlreplication/tests/test_basic.py index e59a520b..e99b347c 100644 --- a/pymysqlreplication/tests/test_basic.py +++ b/pymysqlreplication/tests/test_basic.py @@ -19,7 +19,7 @@ from pymysqlreplication.constants.BINLOG import * from pymysqlreplication.row_event import * -__all__ = ["TestBasicBinLogStreamReader", "TestMultipleRowBinLogStreamReader", "TestCTLConnectionSettings", "TestGtidBinLogStreamReader","TestMariadbBinlogStreamReader"] +__all__ = ["TestBasicBinLogStreamReader", "TestMultipleRowBinLogStreamReader", "TestCTLConnectionSettings", "TestGtidBinLogStreamReader", "TestMariadbBinlogStreamReader", "TestStatementConnectionSetting"] class TestBasicBinLogStreamReader(base.PyMySQLReplicationTestCase): @@ -27,9 +27,9 @@ def ignoredEvents(self): return [GtidEvent] def test_allowed_event_list(self): - self.assertEqual(len(self.stream._allowed_event_list(None, None, False)), 17) - self.assertEqual(len(self.stream._allowed_event_list(None, None, True)), 16) - self.assertEqual(len(self.stream._allowed_event_list(None, [RotateEvent], False)), 16) + self.assertEqual(len(self.stream._allowed_event_list(None, None, False)), 18) + self.assertEqual(len(self.stream._allowed_event_list(None, None, True)), 17) + self.assertEqual(len(self.stream._allowed_event_list(None, [RotateEvent], False)), 17) self.assertEqual(len(self.stream._allowed_event_list([RotateEvent], None, False)), 1) def test_read_query_event(self): @@ -1036,6 +1036,37 @@ def test_annotate_rows_event(self): self.assertEqual(event.sql_statement,insert_query) self.assertIsInstance(event,MariadbAnnotateRowsEvent) +class TestStatementConnectionSetting(base.PyMySQLReplicationTestCase): + def setUp(self): + super(TestStatementConnectionSetting, self).setUp() + self.stream.close() + self.stream = BinLogStreamReader( + self.database, + server_id=1024, + only_events=(RandEvent, QueryEvent), + fail_on_table_metadata_unavailable=True + ) + self.execute("SET @@binlog_format='STATEMENT'") + + def test_rand_event(self): + self.execute("CREATE TABLE test (id INT NOT NULL AUTO_INCREMENT, data INT NOT NULL, PRIMARY KEY (id))") + self.execute("INSERT INTO test (data) VALUES(RAND())") + self.execute("COMMIT") + + self.assertEqual(self.bin_log_format(), "STATEMENT") + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + self.assertIsInstance(self.stream.fetchone(), QueryEvent) + + expect_rand_event = self.stream.fetchone() + self.assertIsInstance(expect_rand_event, RandEvent) + self.assertEqual(type(expect_rand_event.seed1), int) + self.assertEqual(type(expect_rand_event.seed2), int) + + def tearDown(self): + self.execute("SET @@binlog_format='ROW'") + self.assertEqual(self.bin_log_format(), "ROW") + super(TestStatementConnectionSetting, self).tearDown() + if __name__ == "__main__": import unittest