Spark custom data source

1. Create the HBase source table
node1> bin/hbase shell

create 'spark_hbase_sql','cf'
put 'spark_hbase_sql','0001','cf:name','zhangsan'
put 'spark_hbase_sql','0001','cf:score','80'
put 'spark_hbase_sql','0002','cf:name','lisi'
put 'spark_hbase_sql','0002','cf:score','60'

2. Create the HBase table the results will be written to
bin/hbase shell
create 'spark_hbase_write','cf'

3. Write the code

The example below targets the Spark 2.3.x DataSourceV2 API (DataSourceReader / DataReaderFactory / DataSourceWriter); these interfaces were reworked in later Spark versions.
// the package must match the class name passed to spark.read.format(...) below
package com.travel.programApp

import java.util
import java.util.Optional

import com.travel.utils.HbaseTools
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client.{Connection, Put, Result, ResultScanner, Scan, Table}
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader}
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
import org.apache.spark.sql.types.StructType

/**
  * @author cherish
  * @create 2020-04-26 10:49
  */
object HBaseSourceAndSink {
    def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        conf.setMaster("local[2]").setAppName("sparkSqlSourceAndSink")
        val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

        // format() takes the fully qualified class name of our custom data source
        val df: DataFrame = spark.read.format("com.travel.programApp.HBaseSource")
            .option("hbase.table.name", "spark_hbase_sql") // custom options passed through to the data source
            .option("cf.cc", "cf:name,cf:score") // which HBase column-family:qualifier pairs to read
            .option("schema", "`name` STRING , `score` STRING") // schema of the resulting table; columns are returned in this order
            .load

        df.createOrReplaceTempView("sparkHBaseSQL")

        df.printSchema()

        // analyze the result and save it to HBase (or Redis, MySQL, Elasticsearch, ...)
        // with the sample data above, only the row (zhangsan, 80) passes the score > 70 filter
        val resultDF: DataFrame = spark.sql("select * from sparkHBaseSQL where score > 70 ")

        resultDF.write.format("com.travel.programApp.HBaseSource")
            .mode(SaveMode.Overwrite)
            .option("hbase.table.name" , "spark_hbase_write")
            .option("cf" , "cf")
            .save()


    }
}

// Custom data source: supports both reading (ReadSupport) and writing (WriteSupport)
class HBaseSource extends DataSourceV2 with ReadSupport with WriteSupport {
    /**
      * Create the reader for the mapped table
      * @param options values set via spark.read.format(...).option(...)
      * @return
      */
    override def createReader(options: DataSourceOptions): DataSourceReader = {
        // values passed in from spark.read.format(...).option(...)
        val tableName  :String = options.get("hbase.table.name").get()
        val cfAndCC :String = options.get("cf.cc").get()
        val schema:String =  options.get("schema").get()


        new HBaseDataSourceReader(tableName , cfAndCC , schema)
    }

    override def createWriter(jobId: String, schema: StructType, mode: SaveMode, options: DataSourceOptions): Optional[DataSourceWriter] = {

        // values passed in from df.write.format(...).option(...)
        val tableName  :String = options.get("hbase.table.name").get()
        val family :String = options.get("cf").get() // note: the column family is hard-coded to "cf" in HBaseDataWriter below


        Optional.of(new HBaseDataSourceWriter(tableName))
    }
}

class HBaseDataSourceWriter(tableName : String) extends DataSourceWriter {

    /**
      * Driver-side writer: every write goes through the DataWriterFactory created here
      * @return
      */
    override def createWriterFactory(): DataWriterFactory[Row] = {
        new HBaseDataWriterFactory(tableName)
    }

    // called on the driver once every partition has committed successfully
    override def commit(messages: Array[WriterCommitMessage]): Unit = {

    }

    // called on the driver if the write job fails, to abandon/roll back the write
    override def abort(messages: Array[WriterCommitMessage]): Unit = {

    }
}

class HBaseDataWriterFactory(tableName : String) extends DataWriterFactory[Row] {
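    // This factory is created on the driver, serialized, and shipped to the executors;
    // createDataWriter is invoked once per partition (and per task attempt).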
    override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
        new HBaseDataWriter(tableName)
    }
}

class HBaseDataWriter(tableName :String) extends DataWriter[Row] {
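    // Executor-side writer: one instance per partition writes its rows to HBase.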

    private val conn = HbaseTools.getHbaseConn
    private val table: Table = conn.getTable(TableName.valueOf(tableName))

    // write one row to HBase
    override def write(record: Row): Unit = {
        val name:String = record.getString(0)
        val score  :  String = record.getString(1)

        // use the name column as the row key; a fixed key would overwrite itself for every record
        val put = new Put(name.getBytes())

        put.addColumn("cf".getBytes() , "name".getBytes() ,  name.getBytes())
        put.addColumn("cf".getBytes() , "score".getBytes() , score.getBytes())

        table.put(put)

    }

    // per-task commit: called after all rows for this partition have been written
    override def commit(): WriterCommitMessage = {
        // HBase has no transactions, so simply close the table and the connection and return null
        table.close()
        conn.close()
        null
    }

    override def abort(): Unit = {

    }
}

class HBaseDataSourceReader(tableName:String , cfAndCC:String , schema:String) extends DataSourceReader {
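    // Driver-side reader: exposes the schema and plans the reader factories (one per partition).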
    override def readSchema(): StructType = {
        StructType.fromDDL(schema)
    }

    override def createDataReaderFactories(): util.List[DataReaderFactory[Row]] = {
        /**
          * asInstanceOf casts HBaseDateReaderFactory to DataReaderFactory[Row];
          * Seq(...) builds a Scala collection and asJava converts it to the
          * java.util.List this method must return.
          */
        import scala.collection.JavaConverters._
        Seq(new HBaseDateReaderFactory(tableName , cfAndCC).asInstanceOf[DataReaderFactory[Row]]).asJava
    }
}
class HBaseDateReaderFactory(tableName:String,cfAndCC:String) extends DataReaderFactory[Row] {
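    // Created on the driver and serialized to the executor; createDataReader runs on the executor.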
    override def createDataReader(): DataReader[Row] = {
        new HBaseDataReader(tableName , cfAndCC)
    }
}

/**
  * Custom HBaseDataReader implementing the DataReader interface: scans HBase on the executor
  */
class HBaseDataReader(tableName :String , cfAndCC:String) extends DataReader[Row] {
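    // Spark drives this reader with next()/get() until next() returns false, then calls close().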

    var conn: Connection = null
    var table: Table = null

    var scan = new Scan()
    var resultScanner: ResultScanner = null

    /**
      * In a more complete reader, the scan/configuration would typically be serialized on the
      * driver (e.g. with ProtobufUtil) into a string and deserialized here on the executor.
      */

    /**
      * All the HBase scanning happens here
      * @return
      */
    def getIterator: Iterator[Seq[AnyRef]] = {
        conn = HbaseTools.getHbaseConn
        table = conn.getTable(TableName.valueOf(tableName))
        resultScanner = table.getScanner(scan)

        import scala.collection.JavaConverters._

        val iterator: Iterator[Seq[AnyRef]] = resultScanner.iterator().asScala.map(eachResult => {

            /* Parsing the cf.cc option would look like this:
            val cfCCArr: Array[String] = cfAndCC.split(",")
            val family: String = cfCCArr(0).split(":")(0)

            val column1: String = cfCCArr(0).split(":")(1)
            val column2: String = cfCCArr(1).split(":")(1) */

            val name: String = Bytes.toString(eachResult.getValue("cf".getBytes(), "name".getBytes()))
            val score: String = Bytes.toString(eachResult.getValue("cf".getBytes(), "score".getBytes()))
            System.out.println("===================================")
            System.out.println(Seq(name, score).toString())
            System.out.println("===================================")

            Seq(name, score)
        })
        iterator

    }

    val data:Iterator[Seq[AnyRef]] = getIterator

    /**
      * Called repeatedly by Spark: returns true while there is still data to read
      * @return
      */
    override def next(): Boolean = {
        data.hasNext
    }



    /**
      * Each record is converted here into a Row that matches the schema defined above
      * @return
      */
    override def get(): Row ={
        Row.fromSeq(data.next())
    }

    /**
      * Release the HBase resources once the partition has been fully read
      */
    override def close(): Unit = {
        table.close()
        conn.close()
    }
}
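
The code imports HbaseTools from com.travel.utils, which is never shown in the post. A minimal sketch of such a helper, assuming it only wraps HBase's ConnectionFactory (the ZooKeeper quorum below is an assumption; point it at your own cluster):

package com.travel.utils

import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.client.{Connection, ConnectionFactory}

object HbaseTools {
    // Assumed ZooKeeper settings; replace with your cluster's values
    def getHbaseConn: Connection = {
        val conf = HBaseConfiguration.create()
        conf.set("hbase.zookeeper.quorum", "node1,node2,node3")
        conf.set("hbase.zookeeper.property.clientPort", "2181")
        ConnectionFactory.createConnection(conf)
    }
}

Note that getHbaseConn is called once per reader and once per writer instance, so each Spark task opens and closes its own HBase connection.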