欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

教你如何让spark sql写mysql的时候支持update操作

程序员文章站 2022-06-03 23:52:05
目录1、首先了解背景2、如何让sparksql支持update3、改造源码前,需要了解整体的代码设计和执行流程4、改造源码如何让sparksql在对接mysql的时候,除了支持:append、over...

教你如何让spark sql写mysql的时候支持update操作

如何让sparksql在对接mysql的时候,除了支持:append、overwrite、errorifexists、ignore;还要在支持update操作

1、首先了解背景

spark提供了一个枚举类,用来支撑对接数据源的操作模式

教你如何让spark sql写mysql的时候支持update操作

通过源码查看,很明显,spark是不支持update操作的

2、如何让sparksql支持update

关键的知识点就是:

我们正常在sparksql写数据到mysql的时候:

大概的api是:

dataframe.write
          .format("sql.execution.customdatasource.jdbc")
          .option("jdbc.driver", "com.mysql.jdbc.driver")
          .option("jdbc.url", "jdbc:mysql://localhost:3306/test?user=root&password=&useunicode=true&characterencoding=gbk&autoreconnect=true&failoverreadonly=false")
          .option("jdbc.db", "test")
          .save()

那么在底层中,spark会通过jdbc方言jdbcdialect , 将我们要插入的数据翻译成:

insert into student (columns_1 , columns_2 , ...) values (? , ? , ....)

那么通过方言解析出的sql语句就通过preparestatement的executebatch(),将sql语句提交给mysql,然后数据插入;

那么上面的sql语句很明显,完全就是插入代码,并没有我们期望的 update操作,类似:

update table_name set field1=new-value1, field2=new-value2

但是mysql独家支持这样的sql语句:

insert into student (columns_1,columns_2)values ('第一个字段值','第二个字段值') on duplicate key update columns_1 = '呵呵哒',columns_2 = '哈哈哒';

大概的意思就是,如果数据不存在则插入,如果数据存在,则 执行update操作;

因此,我们的切入点就是,让sparksql内部对接jdbcdialect的时候,能够生成这种sql:

insert into 表名称 (columns_1,columns_2)values ('第一个字段值','第二个字段值') on duplicate key update columns_1 = '呵呵哒',columns_2 = '哈哈哒';

3、改造源码前,需要了解整体的代码设计和执行流程

首先是:

dataframe.write

调用write方法就是为了返回一个类:dataframewriter

主要是因为dataframewriter是sparksql对接外部数据源写入的入口携带类,下面这些内容是给dataframewriter注册的携带信息

教你如何让spark sql写mysql的时候支持update操作

然后在出发save()操作后,就开始将数据写入;

接下来看save()源码:

教你如何让spark sql写mysql的时候支持update操作

在上面的源码里面主要是注册datasource实例,然后使用datasource的write方法进行数据写入

实例化datasource的时候:

def save(): unit = {
    assertnotbucketed("save")
    val datasource = datasource(
      df.sparksession,
      classname = source,//自定义数据源的包路径
      partitioncolumns = partitioningcolumns.getorelse(nil),//分区字段
      bucketspec = getbucketspec,//分桶(用于hive)
      options = extraoptions.tomap)//传入的注册信息
    //mode:插入数据方式savemode , df:要插入的数据
    datasource.write(mode, df)
  }

然后就是datasource.write(mode, df)的细节,整段的逻辑就是:

根据providingclass.newinstance()去做模式匹配,然后匹配到哪里,就执行哪里的代码;

教你如何让spark sql写mysql的时候支持update操作

然后看下providingclass是什么:

教你如何让spark sql写mysql的时候支持update操作

教你如何让spark sql写mysql的时候支持update操作

拿到包路径.defaultsource之后,程序进入:

教你如何让spark sql写mysql的时候支持update操作

那么如果是数据库作为写入目标的话,就会走:datasource.createrelation,直接跟进源码:

教你如何让spark sql写mysql的时候支持update操作

很明显是个特质,因此哪里实现了特质,程序就会走到哪里了;

实现这个特质的地方就是:包路径.defaultsource , 然后就在这里面去实现数据的插入和update的支持操作;

4、改造源码

根据代码的流程,最终sparksql 将数据写入mysql的操作,会进入:包路径.defaultsource这个类里面;

也就是说,在这个类里面既要支持spark的正常插入操作(savemode),还要在支持update;

如果让sparksql支持update操作,最关键的就是做一个判断,比如:

if(isupdate){
    sql语句:insert into student (columns_1,columns_2)values ('第一个字段值','第二个字段值') on duplicate key update columns_1 = '呵呵哒',columns_2 = '哈哈哒';
}else{
    insert into student (columns_1 , columns_2 , ...) values (? , ? , ....)
}

但是,在spark生产sql语句的源码中,是这样写的:

教你如何让spark sql写mysql的时候支持update操作

没有任何的判断逻辑,就是最后生成一个:

insert into table (字段1 , 字段2....) values (? , ? ...)

所以首要的任务就是 ,怎么能让当前代码支持:on duplicate key update

可以做个大胆的设计,就是在insertstatement这个方法中做个如下的判断

def insertstatement(conn: connection, savemode:customsavemode , table: string, rddschema: structtype, dialect: jdbcdialect)
      : preparedstatement = {
    val columns = rddschema.fields.map(x => dialect.quoteidentifier(x.name)).mkstring(",")
    val placeholders = rddschema.fields.map(_ => "?").mkstring(",")
    if(savemode == customsavemode.update){
        //todo 如果是update,就组装成on duplicate key update的模式处理
        s"insert into $table ($columns) values ($placeholders) on duplicate key update $duplicatesetting"
    }esle{
        val sql = s"insert into $table ($columns) values ($placeholders)"
        conn.preparestatement(sql)
    }
    
  }

这样,在用户传递进来的savemode模式,我们进行校验,如果是update操作,就返回对应的sql语句!

所以按照上面的逻辑,我们代码这样写:

教你如何让spark sql写mysql的时候支持update操作

这样我们就拿到了对应的sql语句;

但是只有这个sql语句还是不行的,因为在spark中会执行jdbc的preparestatement操作,这里面会涉及到游标。

即jdbc在遍历这个sql的时候,源码会这样做:

教你如何让spark sql写mysql的时候支持update操作

看下makesetter:

教你如何让spark sql写mysql的时候支持update操作

所谓有坑就是:

insert into table (字段1 , 字段2, 字段3) values (? , ? , ?)

那么当前在源码中返回的数组长度应该是3:

val setters: array[jdbcvaluesetter] = rddschema.fields.map(_.datatype)
        .map(makesetter(conn, dialect, _)).toarray

但是如果我们此时支持了update操作,既:

insert into table (字段1 , 字段2, 字段3) values (? , ? , ?) on duplicate key update 字段1 = ?,字段2 = ?,字段3=?;

那么很明显,上面的sql语句提供了6个? , 但在规定字段长度的时候只有3

教你如何让spark sql写mysql的时候支持update操作

这样的话,后面的update操作就无法执行,程序报错!

所以我们需要有一个 识别机制,既:

if(isupdate){
    val numfields = rddschema.fields.length * 2
}else{
    val numfields = rddschema.fields.length
}

教你如何让spark sql写mysql的时候支持update操作

row[1,2,3] setter(0,1) //index of setter , index of row setter(1,2) setter(2,3) setter(3,1) setter(4,2) setter(5,3)

所以在preparestatment中的占位符应该是row的两倍,而且应该是类似这样的一个逻辑

因此,代码改造前样子:

教你如何让spark sql写mysql的时候支持update操作

教你如何让spark sql写mysql的时候支持update操作

改造后的样子:

try {
      if (supportstransactions) {
        conn.setautocommit(false) // everything in the same db transaction.
        conn.settransactionisolation(finalisolationlevel)
      }
//      val stmt = insertstatement(conn, table, rddschema, dialect)
      //此处采用最新自己的sql语句,封装成preparestatement
      val stmt = conn.preparestatement(sqlstmt)
      println(sqlstmt)
      /**
        * 在mysql中有这样的操作:
        * insert into user_admin_t (_id,password) values ('1','第一次插入的密码')
        * insert into user_admin_t (_id,password)values ('1','第一次插入的密码') on duplicate key update _id = 'upid',password = 'uppassword';
        * 如果是下面的on duplicate key操作,那么在preparestatement中的游标会扩增一倍
        * 并且如果没有update操作,那么他的游标是从0开始计数的
        * 如果是update操作,要算上之前的insert操作
        * */
        //makesetter也要适配update操作,即游标问题
​
      val isupdate = savemode == customsavemode.update
      val setters: array[jdbcvaluesetter] = isupdate match {
        case true =>
          val setters: array[jdbcvaluesetter] = rddschema.fields.map(_.datatype)
            .map(makesetter(conn, dialect, _)).toarray
          array.fill(2)(setters).flatten
        case _ =>
          rddschema.fields.map(_.datatype)
      val numfieldslength = rddschema.fields.length
      val numfields = isupdate match{
        case true => numfieldslength *2
        case _ => numfieldslength
      val cursorbegin = numfields / 2
      try {
        var rowcount = 0
        while (iterator.hasnext) {
          val row = iterator.next()
          var i = 0
          while (i < numfields) {
            if(isupdate){
              //需要判断当前游标是否走到了on duplicate key update
              i < cursorbegin match{
                  //说明还没走到update阶段
                case true =>
                  //row.isnullat 判空,则设置空值
                  if (row.isnullat(i)) {
                    stmt.setnull(i + 1, nulltypes(i))
                  } else {
                    setters(i).apply(stmt, row, i, 0)
                  }
                  //说明走到了update阶段
                case false =>
                  if (row.isnullat(i - cursorbegin)) {
                    //pos - offset
                    stmt.setnull(i + 1, nulltypes(i - cursorbegin))
                    setters(i).apply(stmt, row, i, cursorbegin)
              }
            }else{
              if (row.isnullat(i)) {
                stmt.setnull(i + 1, nulltypes(i))
              } else {
                setters(i).apply(stmt, row, i ,0)
            }
            //滚动游标
            i = i + 1
          }
          stmt.addbatch()
          rowcount += 1
          if (rowcount % batchsize == 0) {
            stmt.executebatch()
            rowcount = 0
        }
        if (rowcount > 0) {
          stmt.executebatch()
      } finally {
        stmt.close()
        conn.commit()
      committed = true
      iterator.empty
    } catch {
      case e: sqlexception =>
        val cause = e.getnextexception
        if (cause != null && e.getcause != cause) {
          if (e.getcause == null) {
            e.initcause(cause)
          } else {
            e.addsuppressed(cause)
        throw e
    } finally {
      if (!committed) {
        // the stage must fail.  we got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportstransactions) {
          conn.rollback()
        conn.close()
      } else {
        // the stage must succeed.  we cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: exception => logwarning("transaction succeeded, but closing failed", e)
// a `jdbcvaluesetter` is responsible for setting a value from `row` into a field for
  // `preparedstatement`. the last argument `int` means the index for the value to be set
  // in the sql statement and also used for the value in `row`.
  //preparedstatement, row, position , cursor
  private type jdbcvaluesetter = (preparedstatement, row, int , int) => unit
​
  private def makesetter(
      conn: connection,
      dialect: jdbcdialect,
      datatype: datatype): jdbcvaluesetter = datatype match {
    case integertype =>
      (stmt: preparedstatement, row: row, pos: int,cursor:int) =>
        stmt.setint(pos + 1, row.getint(pos - cursor))
    case longtype =>
        stmt.setlong(pos + 1, row.getlong(pos - cursor))
    case doubletype =>
        stmt.setdouble(pos + 1, row.getdouble(pos - cursor))
    case floattype =>
        stmt.setfloat(pos + 1, row.getfloat(pos - cursor))
    case shorttype =>
        stmt.setint(pos + 1, row.getshort(pos - cursor))
    case bytetype =>
        stmt.setint(pos + 1, row.getbyte(pos - cursor))
    case booleantype =>
        stmt.setboolean(pos + 1, row.getboolean(pos - cursor))
    case stringtype =>
//        println(row.getstring(pos))
        stmt.setstring(pos + 1, row.getstring(pos - cursor))
    case binarytype =>
        stmt.setbytes(pos + 1, row.getas[array[byte]](pos - cursor))
    case timestamptype =>
        stmt.settimestamp(pos + 1, row.getas[java.sql.timestamp](pos - cursor))
    case datetype =>
        stmt.setdate(pos + 1, row.getas[java.sql.date](pos - cursor))
    case t: decimaltype =>
        stmt.setbigdecimal(pos + 1, row.getdecimal(pos - cursor))
    case arraytype(et, _) =>
      // remove type length parameters from end of type name
      val typename = getjdbctype(et, dialect).databasetypedefinition
        .tolowercase.split("\\(")(0)
        val array = conn.createarrayof(
          typename,
          row.getseq[anyref](pos - cursor).toarray)
        stmt.setarray(pos + 1, array)
    case _ =>
      (_: preparedstatement, _: row, pos: int,cursor:int) =>
        throw new illegalargumentexception(
          s"can't translate non-null value for field $pos")
  }

完整代码:

https://github.com/niutaofan/bazinga

到此这篇关于教你如何让spark sql写mysql的时候支持update操作的文章就介绍到这了,更多相关spark sql写mysql支持update内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!