帮助中心/最新通知

质量为本、客户为根、勇于拼搏、务实创新

< 返回文章列表

【开发相关】使用python实现自定义支持mysql协议的SQL改写中间件

发表时间:2025-01-16 01:32:56 小编:主机乐-Yutio

导读

有时候,我们需要做下SQL改写, mysql相关的插件--rewriter.so来实现, 很好用.

但有那么一种支持mysql协议的非mysql数据库也想使用这个功能, 比如:某业务需要查询information_schema.keywords, 但该数据库并没有实现information_schema.keywords(兼容了,但不完全兼容), 而且业务无法修改代码(可能都没得源码), 这时候就可以使用SQL改写来实现兼容性了.

这种"定制的需求",找原厂的话, 流程太慢,而且大概率不会去实现. 那咋办呢?

  1. 修改数据库实现兼容性, 这个方案已经排除了
  2. 修改业务端代码, 最简单, 但本次情况特殊排除
  3. 修改驱动, 可以在驱动发送SQL之前将其改写; 可行,但连代码都无法修改的情况就别指望能修改驱动了.
  4. 加个中间件,由中间件去实现SQL改写

我们之前已经使用自定义中间件实现读写分离,脱敏等功能了, 来个SQL改写自然也不在话下!

实现原理

原理比较简单, 就是加个中间件, 让业务连接中间件, 由中间件去转发业务数据报, 若遇到需要改写的SQL,顺便改写下即可.

连接过程

在此之前,我们先来简单回顾下mysql的连接过程:

连接流程为

  1. 建立tcp连接
  2. 服务端发送Handshake包给客户端(含密码插件,salt,服务器版本等信息)
  3. 客户端根据密码插件加密密码并发送给服务端
  4. 服务端验证密码之后, 返回连接成功/失败.

数据包结构

mysql的数据包格式为: pack_headr+payload

对象

大小

描述

pack_headr

4

包头

payload

n

包体

其中包头由3字节的payload_length和1字节的seq组成, 而不同的包, 包体都不一样,但通常第一字节都为包类型.连接过程的包如下:

SQL包

而我们本次需要改写的则是SQL包, 对应为: COM_QUERY

我们查询官网得到其payload结构如下:

看起来有丢丢复杂, 但只要不考虑CLIENT_QUERY_ATTRIBUTES, 实际上就是0x03+SQL就没了. 主要是复杂在这个CLIENT_QUERY_ATTRIBUTES上, 客户端HandshakeResponse包的前4个字节client_flag中就是记录是否含CLIENT_QUERY_ATTRIBUTES的, 详情如下:

代码语言:txt
AI代码解释
复制
# 客户端和服务端通用的
CLIENT_LONG_PASSWORD = 1 << 0  #旧密码插件 Use the improved version of Old Password Authentication
CLIENT_FOUND_ROWS = 1 << 1  #Send found rows instead of affected rows in EOF_Packet
CLIENT_LONG_FLAG = 1 << 2 # for ColumnDefinition320
CLIENT_CONNECT_WITH_DB = 1 << 3 #是否带有 dbname 
CLIENT_NO_SCHEMA = 1 << 4 #已弃用. (不允许使用db.table.col)
CLIENT_COMPRESS = 1 << 5 #是否支持压缩
CLIENT_ODBC = 1 << 6 #odbc
CLIENT_LOCAL_FILES = 1 << 7 #能否使用 LOAD DATA LOCAL
CLIENT_IGNORE_SPACE = 1 << 8 #是否忽略 括号( 前面的空格
CLIENT_PROTOCOL_41 = 1 << 9 #是否使用CLIENT_PROTOCOL_41
CLIENT_INTERACTIVE = 1 << 10 #是否为交互式终端(就是mysql连接的那种)
CLIENT_SSL = 1 << 11 #是否支持SSL
CLIENT_IGNORE_SIGPIPE = 1 << 12 #网络故障的时候发SIGPIPE
CLIENT_TRANSACTIONS = 1 << 13 #OK/EOF包的status_flags
CLIENT_RESERVED  = 1 << 14 #已弃用 
CLIENT_RESERVED2 = 1 << 15 #已弃用 
CLIENT_MULTI_STATEMENTS = 1 << 16 #是否支持multi-stmt.  COM_QUERY/COM_STMT_PREPARE中多条语句
CLIENT_MULTI_RESULTS = 1 << 17 #multi-results
CLIENT_PS_MULTI_RESULTS = 1 << 18 #PS-protocol
CLIENT_PLUGIN_AUTH = 1 << 19 #是否支持密码插件
CLIENT_CONNECT_ATTRS = 1 << 20 #是否支持连接属性
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21 #密码认证包能否大于255字节
CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = 1 << 22 #不关闭密码过期的连接. 我要改密码..
CLIENT_SESSION_TRACK = 1 << 23 #能够处理服务器状态变更信息
CLIENT_DEPRECATE_EOF = 1 << 24 #OK包代替EOF包.  小坑...
CLIENT_OPTIONAL_RESULTSET_METADATA = 1 << 25 #客户端能处理可选元数据信息
CLIENT_ZSTD_COMPRESSION_ALGORITHM = 1 << 26 #zstd压缩
CLIENT_QUERY_ATTRIBUTES = 1 << 27 #支持COM_QUERY/COM_STMT_EXECUTE中的可选参数
MULTI_FACTOR_AUTHENTICATION = 1 << 28 
CLIENT_CAPABILITY_EXTENSION  = 1 << 29 
CLIENT_SSL_VERIFY_SERVER_CERT = 1 << 30 #验证服务器证书
CLIENT_REMEMBER_OPTIONS = 1 << 31 

没有CLIENT_QUERY_ATTRIBUTES的时候的COM_QUERY包结构如下:

对象

大小

描述

pack_header

4

包头

type

1

包类型, 0x03表示com_query

sql

n

sql语句

当有CLIENT_QUERY_ATTRIBUTES的时候的COM_QUERY包结构如下:

对象

大小

描述

pack_header

4

包头

type

1

包类型, 0x03表示com_query

parameter_count

1-9

参数/属性个数

parameter_set_count

1-9

参数集个数,固定为1

null_bitmap

(parameter_count+7)//8

每个参数1bit表示

new_params_bind_flag

1

固定为1,否则报错:(1835, 'Malformed communication packet.')

param_type_and_flag

2

参数名类型,每个参数都有

parameter name size

1-9

参数名大小

parameter name

parameter name size

参数名

parameter_values size

1-9

参数值大小

parameter_values

parameter_values size

参数值

sql

n

sql语句

这里的参数和值是, 参数解析完之后再解析值

可能有的小伙伴并不知道这个CLIENT_QUERY_ATTRIBUTES到底是个啥? 有毛用?

我们在mysql客户端中使用help即可看到query_attributes

看起来是没啥用的, 我们可以简单抓包测试下:

代码语言:sql
AI代码解释
复制
query_attributes key1 123 key2 456
select 1+1;

抓包结果如下:

也就是设置之后,第一条SQL发送的包中会带有这个信息, 至于干啥的俺就不知道了.

sql改写

在我们能解析这个包之后, 我们就能将解析出来的SQL按照某个规则改写为我们需要的效果. 为了简单, 我这里就直接使用字典映射了.

核心逻辑如下:

代码语言:python
AI代码解释
复制
RWSQLD = {
'原始SQL二进制':'改写之后的SQL二进制',
b'select * from information_schema.keywords':b'select * from mysql.help_keyword'
}

.....
        def handler_msg_c2s(self,rf,sock,CLIENT_QUERY_ATTRIBUTES):
                while True:
                        bdata = read_pack(rf)
                        #print(bdata[4:])
                        if bdata[4:5] == b'\x03': # COM_QUERY
                                start = 5
                                if CLIENT_QUERY_ATTRIBUTES: # have attr
                                        if bdata[5:7] == b'\x00\x01':
                                                start = 7
                                        else:
                                                kl,vl,start = read_query_attributes(bdata)
                                sql = bdata[start:]
                                if sql in RWSQLD:
                                        newsql = RWSQLD[sql]
                                        bdata = struct.pack('<L',len(newsql)+start-4)[:3] + bdata[3:start] + newsql
                                        print(time.time(),sql,'->',newsql)
                        sock.sendall(bdata)

我这里则是直接将'select from information_schema.keywords'语句转为'select from mysql.help_keyword'

验证

不扯了, 直接看效果:

mysql环境验证

information_schema.keywords下的信息应该只有700+, 而我们查询出来的有900+, 说明确实是重写成功了.

ob环境验证

mysql本身就有插件的, 所以这次主要是验证ob之类的兼容mysql的数据库. 但是ob中的mysql.help_keyword本来就是空的(兼容了,但不完全兼容), 所以我们可以创建一个其它表, 比如db1.help_keyword,再将sql改写为查询这个表的,测试如下:

看起来是没得问题了

总结

通过自定义中间件能实现很多功能, 本次是实现的SQL改写. 能实现这些的前提是属性mysql协议之类的基础知识, 这些知识官网都有写的,"MySQL Internals Manual" 也有相关记录.

打好基础很重要!!!

参考:

https://dev.mysql.com/doc/refman/8.0/en/rewriter-query-rewrite-plugin-installation.html

https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html

https://dev.mysql.com/doc/refman/8.0/en/mysql-commands.html

附脚本

这里没有写参数接口, 各信息需要自己在脚本里面修改.

这套脚本是祖传的

代码语言:python
AI代码解释
复制
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# writen by ddcw @https://github.com/ddcw
# 中间件实现SQL改写

import struct
from threading import Thread
from multiprocessing import Process
import socket
import time
import sys
import ssl

# 可修改参数
LISTEN_HOST = '0.0.0.0'
LISTEN_PORT = 3306
SERVER_HOST = '127.0.0.1'
SERVER_PORT = 3314
SSL_CERT = '/data/mysql_3314/mysqldata/server-cert.pem'
SSL_KEY = '/data/mysql_3314/mysqldata/server-key.pem'

# sql改写规则,简单点, 就直接做成dict
RWSQLD = {
'原始SQL二进制':'改写之后的SQL二进制',
b'select * from information_schema.keywords':b'select * from db1.help_keyword'
}

# 不可修改的代码
def read_pack(rf):
	pack_header = rf.read(4)
	if len(pack_header) < 4:
		sys.exit(2)
	btrl, btrh, packet_seq = struct.unpack("<HBB", pack_header)
	pack_size = btrl + (btrh << 16)
	bdata = rf.read(pack_size)
	return pack_header+bdata

def read_lenenc(bdata):
	f1, = struct.unpack('<B',bdata[:1])
	if f1 <= 250:
		return f1,1
	elif f1 == 251:
		return struct.unpack('<H',bdata[1:3])[0],3
	elif f1 == 252:
		return struct.unpack('<L',bdata[1:4]+b'\x00')[0],4
	elif f1 == 253:
		return struct.unpack('<Q',bdata[1:9])[0],9

def read_query_attributes(bdata):
	offset = 5
	kl = []
	vl = []
	parameter_count,toffset = read_lenenc(bdata[offset:])
	offset += toffset
	parameter_set_count,toffset = read_lenenc(bdata[offset:])
	offset += toffset
	if parameter_count > 0:
		null_bitmask_size = (parameter_count+7)//8
		null_bitmask = bdata[offset:offset+null_bitmask_size]
		offset += null_bitmask_size
		new_params_bind_flag,toffset = read_lenenc(bdata[offset:])
		offset += toffset
		for x in range(parameter_count):
			ktype = bdata[offset:offset+2]
			offset += 2
			ksize,toffset = read_lenenc(bdata[offset:])
			offset += toffset
			kname = bdata[offset:offset+ksize]
			offset += ksize
			kl.append(kname)
		for x in range(parameter_count):
			vsize,toffset = read_lenenc(bdata[offset:])
			offset += toffset
			value = bdata[offset:offset+vsize]
			offset += vsize
			vl.append(value)
	return kl,vl,offset

class rewritesql(object): 
	def __init__(self):
		self.host = LISTEN_HOST
		self.port = LISTEN_PORT
		self.server = (SERVER_HOST,SERVER_PORT)
		self.cert = SSL_CERT
		self.key = SSL_KEY
		
	def handler_msg_c2s(self,rf,sock,CLIENT_QUERY_ATTRIBUTES):
		while True:
			bdata = read_pack(rf)
			#print(bdata[4:])
			if bdata[4:5] == b'\x03': # COM_QUERY
				start = 5
				if CLIENT_QUERY_ATTRIBUTES: # have attr
					if bdata[5:7] == b'\x00\x01':
						start = 7
					else:
						kl,vl,start = read_query_attributes(bdata)
				sql = bdata[start:]
				if sql in RWSQLD:
					newsql = RWSQLD[sql]
					bdata = struct.pack('<L',len(newsql)+start-4)[:3] + bdata[3:start] + newsql
					print(time.time(),sql,'->',newsql)
			sock.sendall(bdata)

	def handler_msg_s2c(self,rf,sock):
		while True:
			bdata = read_pack(rf)
			sock.sendall(bdata)

	def handler(self,conn,addr):
		sock = socket.create_connection((self.server[0], self.server[1]))
		server_rf = sock.makefile('rb')
		bdata = read_pack(server_rf)
		offset = bdata[5:].find(b'\x00')+5 + 1 + 4 + 8 + 1
		server_cap, = struct.unpack('<L',bdata[offset:offset+2] + bdata[offset+3:offset+5])
		conn.sendall(bdata)

		client_rf = conn.makefile('rb')
		bdata = read_pack(client_rf) 
		client_cap, = struct.unpack('<L',bdata[4:8])
		print('S:',server_cap,'C:',client_cap)
		sock.sendall(bdata)
		CLIENT_QUERY_ATTRIBUTES = True if server_cap&(1<<5) > 0 else False

		#if len(bdata) < 38: #封装为SSL (32+4)
		if client_cap & (1 << 11):
			#封装客户端的SSL (因为相对于client, 这是server角色)
			context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
			context.load_cert_chain(certfile=self.cert, keyfile=self.key)
			conn = context.wrap_socket(conn, server_side=True)
			client_rf = conn.makefile('rb')

			#封装到server的SSL
			sock = ssl.wrap_socket(sock)
			server_rf = sock.makefile('rb')
		
		t1 = Process(target=self.handler_msg_c2s,args=(client_rf,sock,CLIENT_QUERY_ATTRIBUTES))
		t2 = Process(target=self.handler_msg_s2c,args=(server_rf,conn))
		t1.start()
		t2.start()
		t1.join()
		t2.join()

	def init(self):
		socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
		socket_server.bind((self.host, self.port))
		socket_server.listen(12345) #设置连接数
		self.socket_server = socket_server

		accept_client_thread = Thread(target=self.accept_client,)
		accept_client_thread.start()
		accept_client_thread.join()
		
	def accept_client(self,):
		while True:
			conn, addr = self.socket_server.accept()
			p = Process(target=self.handler,args=(conn,addr),)
			p.start()
	
if __name__ == '__main__':
	rwsql = rewritesql()
	rwsql.init()

联系我们
返回顶部