Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Developed UserVarEvent and Added Statement-Based Logging Test #466

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pymysqlreplication/binlogstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
BeginLoadQueryEvent, ExecuteLoadQueryEvent,
HeartbeatLogEvent, NotImplementedEvent, MariadbGtidEvent,
MariadbAnnotateRowsEvent, RandEvent, MariadbStartEncryptionEvent, RowsQueryLogEvent,
MariadbGtidListEvent, MariadbBinLogCheckPointEvent)
MariadbGtidListEvent, MariadbBinLogCheckPointEvent, UserVarEvent)
from .exceptions import BinLogNotEnabled
from .gtid import GtidSet
from .packet import BinLogPacketWrapper
Expand Down Expand Up @@ -626,7 +626,8 @@ def _allowed_event_list(self, only_events, ignored_events,
RandEvent,
MariadbStartEncryptionEvent,
MariadbGtidListEvent,
MariadbBinLogCheckPointEvent
MariadbBinLogCheckPointEvent,
UserVarEvent
))
if ignored_events is not None:
for e in ignored_events:
Expand Down
147 changes: 145 additions & 2 deletions pymysqlreplication/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import binascii
import struct
import datetime
import decimal
from pymysqlreplication.constants.STATUS_VAR_KEY import *
from pymysqlreplication.exceptions import StatusVariableMismatch
from typing import Union, Optional


class BinLogEvent(object):
Expand Down Expand Up @@ -51,7 +53,6 @@ def _dump(self):
"""Core data dumped for the event"""
pass


class GtidEvent(BinLogEvent):
"""GTID change in binlog event
"""
Expand Down Expand Up @@ -519,7 +520,6 @@ class RandEvent(BinLogEvent):
:ivar seed1: int - value for the first seed
:ivar seed2: int - value for the second seed
"""

def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super().__init__(from_packet, event_size, table_map,
ctl_connection, **kwargs)
Expand All @@ -542,6 +542,149 @@ def _dump(self):
print("seed1: %d" % (self.seed1))
print("seed2: %d" % (self.seed2))

class UserVarEvent(BinLogEvent):
"""
UserVarEvent is generated every time a statement uses a user variable.
Indicates the value to use for the user variable in the next statement.

:ivar name_len: int - Length of user variable
:ivar name: str - User variable name
:ivar value: str - Value of the user variable
:ivar type: int - Type of the user variable
:ivar charset: int - The number of the character set for the user variable
:ivar is_null: int - Non-zero if the variable value is the SQL NULL value, 0 otherwise
:ivar flags: int - Extra flags associated with the user variable
"""

def __init__(self, from_packet, event_size, table_map, ctl_connection, **kwargs):
super(UserVarEvent, self).__init__(from_packet, event_size, table_map, ctl_connection, **kwargs)

# Payload
self.name_len: int = self.packet.read_uint32()
self.name: str = self.packet.read(self.name_len).decode()
self.is_null: int = self.packet.read_uint8()
self.type_to_codes_and_method: dict = {
0x00: ['STRING_RESULT', self._read_string],
0x01: ['REAL_RESULT', self._read_real],
0x02: ['INT_RESULT', self._read_int],
0x03: ['ROW_RESULT', self._read_default],
0x04: ['DECIMAL_RESULT', self._read_decimal]
}

self.value: Optional[Union[str, float, int, decimal.Decimal]] = None
self.flags: Optional[int] = None
self.temp_value_buffer: Union[bytes, memoryview] = b''

if not self.is_null:
self.type: int = self.packet.read_uint8()
self.charset: int = self.packet.read_uint32()
self.value_len: int = self.packet.read_uint32()
self.temp_value_buffer: Union[bytes, memoryview] = self.packet.read(self.value_len)
self.flags: int = self.packet.read_uint8()
self._set_value_from_temp_buffer()
else:
self.type, self.charset, self.value_len, self.value, self.flags = None, None, None, None, None

def _set_value_from_temp_buffer(self):
"""
Set the value from the temporary buffer based on the type code.
"""
if self.temp_value_buffer:
type_code, read_method = self.type_to_codes_and_method.get(self.type, ["UNKNOWN_RESULT", self._read_default])
if type_code == 'INT_RESULT':
self.value = read_method(self.temp_value_buffer, self.flags)
else:
self.value = read_method(self.temp_value_buffer)

def _read_string(self, buffer: bytes) -> str:
"""
Read string data.
"""
return buffer.decode()

def _read_real(self, buffer: bytes) -> float:
"""
Read real data.
"""
return struct.unpack('<d', buffer)[0]

def _read_int(self, buffer: bytes, flags: int) -> int:
"""
Read integer data.
"""
fmt = '<Q' if flags == 1 else '<q'
return struct.unpack(fmt, buffer)[0]

def _read_decimal(self, buffer: bytes) -> decimal.Decimal:
"""
Read decimal data.
"""
self.precision = self.temp_value_buffer[0]
self.decimals = self.temp_value_buffer[1]
raw_decimal = self.temp_value_buffer[2:]
return self._parse_decimal_from_bytes(raw_decimal, self.precision, self.decimals)

def _read_default(self) -> bytes:
"""
Read default data.
Used when the type is None.
"""
return self.packet.read(self.value_len)

@staticmethod
def _parse_decimal_from_bytes(raw_decimal: bytes, precision: int, decimals: int) -> decimal.Decimal:
"""
Parse decimal from bytes.
"""
digits_per_integer = 9
compressed_bytes = [0, 1, 1, 2, 2, 3, 3, 4, 4, 4]
integral = precision - decimals

uncomp_integral, comp_integral = divmod(integral, digits_per_integer)
uncomp_fractional, comp_fractional = divmod(decimals, digits_per_integer)

res = "-" if not raw_decimal[0] & 0x80 else ""
mask = -1 if res == "-" else 0
raw_decimal = bytearray([raw_decimal[0] ^ 0x80]) + raw_decimal[1:]

def decode_decimal_decompress_value(comp_indx, data, mask):
size = compressed_bytes[comp_indx]
if size > 0:
databuff = bytearray(data[:size])
for i in range(size):
databuff[i] = (databuff[i] ^ mask) & 0xFF
return size, int.from_bytes(databuff, byteorder='big')
return 0, 0

pointer, value = decode_decimal_decompress_value(comp_integral, raw_decimal, mask)
res += str(value)

for _ in range(uncomp_integral):
value = struct.unpack('>i', raw_decimal[pointer:pointer+4])[0] ^ mask
res += '%09d' % value
pointer += 4

res += "."

for _ in range(uncomp_fractional):
value = struct.unpack('>i', raw_decimal[pointer:pointer+4])[0] ^ mask
res += '%09d' % value
pointer += 4

size, value = decode_decimal_decompress_value(comp_fractional, raw_decimal[pointer:], mask)
if size > 0:
res += '%0*d' % (comp_fractional, value)
return decimal.Decimal(res)

def _dump(self) -> None:
super(UserVarEvent, self)._dump()
print("User variable name: %s" % self.name)
print("Is NULL: %s" % ("Yes" if self.is_null else "No"))
if not self.is_null:
print("Type: %s" % self.type_to_codes_and_method.get(self.type, ['UNKNOWN_TYPE'])[0])
print("Charset: %s" % self.charset)
print("Value: %s" % self.value)
print("Flags: %s" % self.flags)

class MariadbStartEncryptionEvent(BinLogEvent):
"""
Expand Down
1 change: 1 addition & 0 deletions pymysqlreplication/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class BinLogPacketWrapper(object):
constants.XA_PREPARE_EVENT: event.XAPrepareEvent,
constants.ROWS_QUERY_LOG_EVENT: event.RowsQueryLogEvent,
constants.RAND_EVENT: event.RandEvent,
constants.USER_VAR_EVENT: event.UserVarEvent,
# row_event
constants.UPDATE_ROWS_EVENT_V1: row_event.UpdateRowsEvent,
constants.WRITE_ROWS_EVENT_V1: row_event.WriteRowsEvent,
Expand Down
Loading