帮助中心/最新通知

质量为本、客户为根、勇于拼搏、务实创新

< 返回文章列表

【服务器相关】教你如何让spark sql写mysql的时候支持update操作

发表时间:2025-06-16 03:46:00 小编:主机乐-Yutio

如何让sparkSQL在对接mysql的时候,除了支持:Append、Overwrite、ErrorIfExists、Ignore;还要在支持update操作

1、首先了解背景

spark提供了一个枚举类,用来支撑对接数据源的操作模式

通过源码查看,很明显,spark是不支持update操作的

2、如何让sparkSQL支持update

关键的知识点就是:

我们正常在sparkSQL写数据到mysql的时候:

大概的api是:

然后在出发save()操作后,就开始将数据写入;

接下来看save()源码:

在上面的源码里面主要是注册DataSource实例,然后使用DataSource的write方法进行数据写入

实例化DataSource的时候:

然后看下providingClass是什么:

拿到包路径.DefaultSource之后,程序进入:

那么如果是数据库作为写入目标的话,就会走:dataSource.createRelation,直接跟进源码:

很明显是个特质,因此哪里实现了特质,程序就会走到哪里了;

实现这个特质的地方就是:包路径.DefaultSource , 然后就在这里面去实现数据的插入和update的支持操作;

4、改造源码

根据代码的流程,最终sparkSQL 将数据写入mysql的操作,会进入:包路径.DefaultSource这个类里面;

也就是说,在这个类里面既要支持spark的正常插入操作(SaveMode),还要在支持update;

如果让sparksql支持update操作,最关键的就是做一个判断,比如:

没有任何的判断逻辑,就是最后生成一个:

这样我们就拿到了对应的sql语句;

但是只有这个sql语句还是不行的,因为在spark中会执行jdbc的prepareStatement操作,这里面会涉及到游标。

即jdbc在遍历这个sql的时候,源码会这样做:

看下makeSetter:

所谓有坑就是:

这样的话,后面的update操作就无法执行,程序报错!

所以我们需要有一个 识别机制,既:

row[1,2,3] setter(0,1) //index of setter , index of row setter(1,2) setter(2,3) setter(3,1) setter(4,2) setter(5,3)

所以在prepareStatment中的占位符应该是row的两倍,而且应该是类似这样的一个逻辑

因此,代码改造前样子:

改造后的样子:


try {
if (supportsTransactions) {
conn.setAutoCommit(false) // Everything in the same db transaction.
conn.setTransactionIsolation(finalIsolationLevel)
}
//val stmt = insertStatement(conn, table, rddSchema, dialect)
//此处采用最新自己的sql语句,封装成prepareStatement
val stmt = conn.prepareStatement(sqlStmt)
println(sqlStmt)

//makeSetter也要适配update操作,即游标问题

val isUpdate = saveMode == CustomSaveMode.Update
val setters: Array[JDBCValueSetter] = isUpdate match {
case true =>
val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
.map(makeSetter(conn, dialect, _)).toArray
Array.fill(2)(setters).flatten
case _ =>
rddSchema.fields.map(_.dataType)
val numFieldsLength = rddSchema.fields.length
val numFields = isUpdate match{
case true => numFieldsLength *2
case _ => numFieldsLength
val cursorBegin = numFields / 2
try {
var rowCount = 0
while (iterator.hasNext) {
val row = iterator.next()
var i = 0
while (i < numFields) {
if(isUpdate){
//需要判断当前游标是否走到了ON DUPLICATE KEY UPDATE
i < cursorBegin match{
//说明还没走到update阶段
case true =>
//row.isNullAt 判空,则设置空值
if (row.isNullAt(i)) {
stmt.setNull(i + 1, nullTypes(i))
} else {
setters(i).apply(stmt, row, i, 0)
}
//说明走到了update阶段
case false =>
if (row.isNullAt(i – cursorBegin)) {
//pos – offset
stmt.setNull(i + 1, nullTypes(i – cursorBegin))
setters(i).apply(stmt, row, i, cursorBegin)
}
}else{
if (row.isNullAt(i)) {
stmt.setNull(i + 1, nullTypes(i))
} else {
setters(i).apply(stmt, row, i ,0)
}
//滚动游标
i = i + 1
}
stmt.addBatch()
rowCount += 1
if (rowCount % batchSize == 0) {
stmt.executeBatch()
rowCount = 0
}
if (rowCount > 0) {
stmt.executeBatch()
} finally {
stmt.close()
conn.commit()
committed = true
Iterator.empty
} catch {
case e: SQLException =>
val cause = e.getNextException
if (cause != null && e.getCause != cause) {
if (e.getCause == null) {
e.initCause(cause)
} else {
e.addSuppressed(cause)
throw e
} finally {
if (!committed) {
// The stage must fail.We got here through an exception path, so
// let the exception through unless rollback() or close() want to
// tell the user about another problem.
if (supportsTransactions) {
conn.rollback()
conn.close()
} else {
// The stage must succeed.We cannot propagate any exception close() might throw.
try {
conn.close()
} catch {
case e: Exception => logWarning(“Transaction succeeded, but closing failed”, e)
// A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
// `PreparedStatement`. The last argument `Int` means the index for the value to be set
// in the SQL statement and also used for the value in `Row`.
//PreparedStatement, Row, position , cursor
private type JDBCValueSetter = (PreparedStatement, Row, Int , Int) => Unit

private def makeSetter(
conn: Connection,
dialect: JdbcDialect,
dataType: DataType): JDBCValueSetter = dataType match {
case IntegerType =>
(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
stmt.setInt(pos + 1, row.getInt(pos – cursor))
case LongType =>
stmt.setLong(pos + 1, row.getLong(pos – cursor))
case DoubleType =>
stmt.setDouble(pos + 1, row.getDouble(pos – cursor))
case FloatType =>
stmt.setFloat(pos + 1, row.getFloat(pos – cursor))
case ShortType =>
stmt.setInt(pos + 1, row.getShort(pos – cursor))
case ByteType =>
stmt.setInt(pos + 1, row.getByte(pos – cursor))
case BooleanType =>
stmt.setBoolean(pos + 1, row.getBoolean(pos – cursor))
case StringType =>
//println(row.getString(pos))
stmt.setString(pos + 1, row.getString(pos – cursor))
case BinaryType =>
stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos – cursor))
case TimestampType =>
stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos – cursor))
case DateType =>
stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos – cursor))
case t: DecimalType =>
stmt.setBigDecimal(pos + 1, row.getDecimal(pos – cursor))
case ArrayType(et, _) =>
// remove type length parameters from end of type name
val typeName = getJdbcType(et, dialect).databaseTypeDefinition
.toLowerCase.split(“\\(“)(0)
val array = conn.createArrayOf(
typeName,
row.getSeq[AnyRef](pos – cursor).toArray)
stmt.setArray(pos + 1, array)
case _ =>
(_: PreparedStatement, _: Row, pos: Int,cursor:Int) =>
throw new IllegalArgumentException(
s”Can’t translate non-null value for field $pos”)
}

完整代码:

https://github.com/niutaofan/bazinga

到此这篇关于教你如何让spark sql写mysql的时候支持update操作的文章就介绍到这了,更多相关spark sql写mysql支持update内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!


联系我们
返回顶部