有时候,我们需要做下SQL改写, mysql有相关的插件--rewriter.so来实现, 很好用.
但有那么一种支持mysql协议的非mysql数据库也想使用这个功能, 比如:某业务需要查询information_schema.keywords, 但该数据库并没有实现information_schema.keywords(兼容了,但不完全兼容), 而且业务无法修改代码(可能都没得源码), 这时候就可以使用SQL改写来实现兼容性了.
这种"定制的需求",找原厂的话, 流程太慢,而且大概率不会去实现. 那咋办呢?
我们之前已经使用自定义中间件实现读写分离,脱敏等功能了, 来个SQL改写自然也不在话下!
原理比较简单, 就是加个中间件, 让业务连接中间件, 由中间件去转发业务数据报, 若遇到需要改写的SQL,顺便改写下即可.

在此之前,我们先来简单回顾下mysql的连接过程:
连接流程为

mysql的数据包格式为: pack_headr+payload
对象 | 大小 | 描述 |
|---|---|---|
pack_headr | 4 | 包头 |
payload | n | 包体 |
其中包头由3字节的payload_length和1字节的seq组成, 而不同的包, 包体都不一样,但通常第一字节都为包类型.连接过程的包如下:

而我们本次需要改写的则是SQL包, 对应为: COM_QUERY
我们查询官网得到其payload结构如下:

看起来有丢丢复杂, 但只要不考虑CLIENT_QUERY_ATTRIBUTES, 实际上就是0x03+SQL就没了. 主要是复杂在这个CLIENT_QUERY_ATTRIBUTES上, 客户端HandshakeResponse包的前4个字节client_flag中就是记录是否含CLIENT_QUERY_ATTRIBUTES的, 详情如下:
# 客户端和服务端通用的
CLIENT_LONG_PASSWORD = 1 << 0 #旧密码插件 Use the improved version of Old Password Authentication
CLIENT_FOUND_ROWS = 1 << 1 #Send found rows instead of affected rows in EOF_Packet
CLIENT_LONG_FLAG = 1 << 2 # for ColumnDefinition320
CLIENT_CONNECT_WITH_DB = 1 << 3 #是否带有 dbname
CLIENT_NO_SCHEMA = 1 << 4 #已弃用. (不允许使用db.table.col)
CLIENT_COMPRESS = 1 << 5 #是否支持压缩
CLIENT_ODBC = 1 << 6 #odbc
CLIENT_LOCAL_FILES = 1 << 7 #能否使用 LOAD DATA LOCAL
CLIENT_IGNORE_SPACE = 1 << 8 #是否忽略 括号( 前面的空格
CLIENT_PROTOCOL_41 = 1 << 9 #是否使用CLIENT_PROTOCOL_41
CLIENT_INTERACTIVE = 1 << 10 #是否为交互式终端(就是mysql连接的那种)
CLIENT_SSL = 1 << 11 #是否支持SSL
CLIENT_IGNORE_SIGPIPE = 1 << 12 #网络故障的时候发SIGPIPE
CLIENT_TRANSACTIONS = 1 << 13 #OK/EOF包的status_flags
CLIENT_RESERVED = 1 << 14 #已弃用
CLIENT_RESERVED2 = 1 << 15 #已弃用
CLIENT_MULTI_STATEMENTS = 1 << 16 #是否支持multi-stmt. COM_QUERY/COM_STMT_PREPARE中多条语句
CLIENT_MULTI_RESULTS = 1 << 17 #multi-results
CLIENT_PS_MULTI_RESULTS = 1 << 18 #PS-protocol
CLIENT_PLUGIN_AUTH = 1 << 19 #是否支持密码插件
CLIENT_CONNECT_ATTRS = 1 << 20 #是否支持连接属性
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21 #密码认证包能否大于255字节
CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = 1 << 22 #不关闭密码过期的连接. 我要改密码..
CLIENT_SESSION_TRACK = 1 << 23 #能够处理服务器状态变更信息
CLIENT_DEPRECATE_EOF = 1 << 24 #OK包代替EOF包. 小坑...
CLIENT_OPTIONAL_RESULTSET_METADATA = 1 << 25 #客户端能处理可选元数据信息
CLIENT_ZSTD_COMPRESSION_ALGORITHM = 1 << 26 #zstd压缩
CLIENT_QUERY_ATTRIBUTES = 1 << 27 #支持COM_QUERY/COM_STMT_EXECUTE中的可选参数
MULTI_FACTOR_AUTHENTICATION = 1 << 28
CLIENT_CAPABILITY_EXTENSION = 1 << 29
CLIENT_SSL_VERIFY_SERVER_CERT = 1 << 30 #验证服务器证书
CLIENT_REMEMBER_OPTIONS = 1 << 31 没有CLIENT_QUERY_ATTRIBUTES的时候的COM_QUERY包结构如下:
对象 | 大小 | 描述 |
|---|---|---|
pack_header | 4 | 包头 |
type | 1 | 包类型, 0x03表示com_query |
sql | n | sql语句 |
当有CLIENT_QUERY_ATTRIBUTES的时候的COM_QUERY包结构如下:
对象 | 大小 | 描述 |
|---|---|---|
pack_header | 4 | 包头 |
type | 1 | 包类型, 0x03表示com_query |
parameter_count | 1-9 | 参数/属性个数 |
parameter_set_count | 1-9 | 参数集个数,固定为1 |
null_bitmap | (parameter_count+7)//8 | 每个参数1bit表示 |
new_params_bind_flag | 1 | 固定为1,否则报错:(1835, 'Malformed communication packet.') |
param_type_and_flag | 2 | 参数名类型,每个参数都有 |
parameter name size | 1-9 | 参数名大小 |
parameter name | parameter name size | 参数名 |
parameter_values size | 1-9 | 参数值大小 |
parameter_values | parameter_values size | 参数值 |
sql | n | sql语句 |
这里的参数和值是, 参数解析完之后再解析值
可能有的小伙伴并不知道这个CLIENT_QUERY_ATTRIBUTES到底是个啥? 有毛用?
我们在mysql客户端中使用help即可看到query_attributes

看起来是没啥用的, 我们可以简单抓包测试下:
query_attributes key1 123 key2 456
select 1+1;
抓包结果如下:

也就是设置之后,第一条SQL发送的包中会带有这个信息, 至于干啥的俺就不知道了.
在我们能解析这个包之后, 我们就能将解析出来的SQL按照某个规则改写为我们需要的效果. 为了简单, 我这里就直接使用字典映射了.
核心逻辑如下:
RWSQLD = {
'原始SQL二进制':'改写之后的SQL二进制',
b'select * from information_schema.keywords':b'select * from mysql.help_keyword'
}
.....
def handler_msg_c2s(self,rf,sock,CLIENT_QUERY_ATTRIBUTES):
while True:
bdata = read_pack(rf)
#print(bdata[4:])
if bdata[4:5] == b'\x03': # COM_QUERY
start = 5
if CLIENT_QUERY_ATTRIBUTES: # have attr
if bdata[5:7] == b'\x00\x01':
start = 7
else:
kl,vl,start = read_query_attributes(bdata)
sql = bdata[start:]
if sql in RWSQLD:
newsql = RWSQLD[sql]
bdata = struct.pack('<L',len(newsql)+start-4)[:3] + bdata[3:start] + newsql
print(time.time(),sql,'->',newsql)
sock.sendall(bdata)我这里则是直接将'select from information_schema.keywords'语句转为'select from mysql.help_keyword'
不扯了, 直接看效果:


information_schema.keywords下的信息应该只有700+, 而我们查询出来的有900+, 说明确实是重写成功了.
mysql本身就有插件的, 所以这次主要是验证ob之类的兼容mysql的数据库. 但是ob中的mysql.help_keyword本来就是空的(兼容了,但不完全兼容), 所以我们可以创建一个其它表, 比如db1.help_keyword,再将sql改写为查询这个表的,测试如下:

看起来是没得问题了
通过自定义中间件能实现很多功能, 本次是实现的SQL改写. 能实现这些的前提是属性mysql协议之类的基础知识, 这些知识官网都有写的,"MySQL Internals Manual" 也有相关记录.
打好基础很重要!!!
参考:
https://dev.mysql.com/doc/refman/8.0/en/rewriter-query-rewrite-plugin-installation.html
https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html
https://dev.mysql.com/doc/refman/8.0/en/mysql-commands.html
这里没有写参数接口, 各信息需要自己在脚本里面修改.
这套脚本是祖传的
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# writen by ddcw @https://github.com/ddcw
# 中间件实现SQL改写
import struct
from threading import Thread
from multiprocessing import Process
import socket
import time
import sys
import ssl
# 可修改参数
LISTEN_HOST = '0.0.0.0'
LISTEN_PORT = 3306
SERVER_HOST = '127.0.0.1'
SERVER_PORT = 3314
SSL_CERT = '/data/mysql_3314/mysqldata/server-cert.pem'
SSL_KEY = '/data/mysql_3314/mysqldata/server-key.pem'
# sql改写规则,简单点, 就直接做成dict
RWSQLD = {
'原始SQL二进制':'改写之后的SQL二进制',
b'select * from information_schema.keywords':b'select * from db1.help_keyword'
}
# 不可修改的代码
def read_pack(rf):
pack_header = rf.read(4)
if len(pack_header) < 4:
sys.exit(2)
btrl, btrh, packet_seq = struct.unpack("<HBB", pack_header)
pack_size = btrl + (btrh << 16)
bdata = rf.read(pack_size)
return pack_header+bdata
def read_lenenc(bdata):
f1, = struct.unpack('<B',bdata[:1])
if f1 <= 250:
return f1,1
elif f1 == 251:
return struct.unpack('<H',bdata[1:3])[0],3
elif f1 == 252:
return struct.unpack('<L',bdata[1:4]+b'\x00')[0],4
elif f1 == 253:
return struct.unpack('<Q',bdata[1:9])[0],9
def read_query_attributes(bdata):
offset = 5
kl = []
vl = []
parameter_count,toffset = read_lenenc(bdata[offset:])
offset += toffset
parameter_set_count,toffset = read_lenenc(bdata[offset:])
offset += toffset
if parameter_count > 0:
null_bitmask_size = (parameter_count+7)//8
null_bitmask = bdata[offset:offset+null_bitmask_size]
offset += null_bitmask_size
new_params_bind_flag,toffset = read_lenenc(bdata[offset:])
offset += toffset
for x in range(parameter_count):
ktype = bdata[offset:offset+2]
offset += 2
ksize,toffset = read_lenenc(bdata[offset:])
offset += toffset
kname = bdata[offset:offset+ksize]
offset += ksize
kl.append(kname)
for x in range(parameter_count):
vsize,toffset = read_lenenc(bdata[offset:])
offset += toffset
value = bdata[offset:offset+vsize]
offset += vsize
vl.append(value)
return kl,vl,offset
class rewritesql(object):
def __init__(self):
self.host = LISTEN_HOST
self.port = LISTEN_PORT
self.server = (SERVER_HOST,SERVER_PORT)
self.cert = SSL_CERT
self.key = SSL_KEY
def handler_msg_c2s(self,rf,sock,CLIENT_QUERY_ATTRIBUTES):
while True:
bdata = read_pack(rf)
#print(bdata[4:])
if bdata[4:5] == b'\x03': # COM_QUERY
start = 5
if CLIENT_QUERY_ATTRIBUTES: # have attr
if bdata[5:7] == b'\x00\x01':
start = 7
else:
kl,vl,start = read_query_attributes(bdata)
sql = bdata[start:]
if sql in RWSQLD:
newsql = RWSQLD[sql]
bdata = struct.pack('<L',len(newsql)+start-4)[:3] + bdata[3:start] + newsql
print(time.time(),sql,'->',newsql)
sock.sendall(bdata)
def handler_msg_s2c(self,rf,sock):
while True:
bdata = read_pack(rf)
sock.sendall(bdata)
def handler(self,conn,addr):
sock = socket.create_connection((self.server[0], self.server[1]))
server_rf = sock.makefile('rb')
bdata = read_pack(server_rf)
offset = bdata[5:].find(b'\x00')+5 + 1 + 4 + 8 + 1
server_cap, = struct.unpack('<L',bdata[offset:offset+2] + bdata[offset+3:offset+5])
conn.sendall(bdata)
client_rf = conn.makefile('rb')
bdata = read_pack(client_rf)
client_cap, = struct.unpack('<L',bdata[4:8])
print('S:',server_cap,'C:',client_cap)
sock.sendall(bdata)
CLIENT_QUERY_ATTRIBUTES = True if server_cap&(1<<5) > 0 else False
#if len(bdata) < 38: #封装为SSL (32+4)
if client_cap & (1 << 11):
#封装客户端的SSL (因为相对于client, 这是server角色)
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(certfile=self.cert, keyfile=self.key)
conn = context.wrap_socket(conn, server_side=True)
client_rf = conn.makefile('rb')
#封装到server的SSL
sock = ssl.wrap_socket(sock)
server_rf = sock.makefile('rb')
t1 = Process(target=self.handler_msg_c2s,args=(client_rf,sock,CLIENT_QUERY_ATTRIBUTES))
t2 = Process(target=self.handler_msg_s2c,args=(server_rf,conn))
t1.start()
t2.start()
t1.join()
t2.join()
def init(self):
socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_server.bind((self.host, self.port))
socket_server.listen(12345) #设置连接数
self.socket_server = socket_server
accept_client_thread = Thread(target=self.accept_client,)
accept_client_thread.start()
accept_client_thread.join()
def accept_client(self,):
while True:
conn, addr = self.socket_server.accept()
p = Process(target=self.handler,args=(conn,addr),)
p.start()
if __name__ == '__main__':
rwsql = rewritesql()
rwsql.init()