欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Spark SQL 中的UDF、UDAF、UDTF

程序员文章站 2022-04-29 08:07:37
...

UDF

UDF(User-defined functions)用户自定义函数,简单说就是输入一行输出一行的自定义算子。(一对一)
数据文件:hobbies.txt,第一列为姓名,其他为兴趣爱好

alice,jogging&Coding&cooking
lina,traveldance&cooking

自定义UDF,实现的是计算每个人的兴趣爱好个数

// 样例类
case class Hobbies(name:String,hobbies:String)
object UDFDemo {
  def main(args: Array[String]): Unit = {
  	// 获取SparkSession对象
    val spark: SparkSession = SparkSession.builder()
      .appName("udfDemo")
      .master("local[*]")
      .getOrCreate()
	// 获取SparkContext对象
    val sc:SparkContext = spark.sparkContext
    import spark.implicits._
	// 读取文件
    val rdd1: RDD[String] = sc.textFile("in/hobbies.txt")
    // 将姓名与爱好以逗号分隔,创建成样例类后转成DataFrame
    val df: DataFrame = rdd1.map(_.split(","))
      .map(x => Hobbies(x(0), x(1)))
      .toDF()
	// 注册临时表
    df.registerTempTable("hobbies")
    // 注册udf,名字为hoby_num,功能用匿名函数代替
    spark.udf.register("hoby_num",
      (x:String)=>x.split("&").length)
	// 查询
    val frame: DataFrame = spark.sql("select name,hobbies,hoby_num(hobbies) from hobbies")
    frame.show()
  }
}

Spark SQL 中的UDF、UDAF、UDTF

UDAF

UDAF(User Defined Aggregate Function),即用户定义的聚合函数,聚合函数和普通函数的区别是:普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值想办法聚合一下。(多对一)
数据文件:user.json

{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}

自定UDAF,实现计算平均年龄

object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("udafDemo")
      .getOrCreate()

    val df: DataFrame = spark.read.json("in/user.json")
    spark.udf.register("ageAvg",new MyAgeAvgFunction)
	// 创建临时视图
    df.createTempView("userInfo")
	// 查询
    val frame: DataFrame = spark.sql("select sex,ageAvg(age) from userInfo group by sex")
    frame.printSchema()
    frame.show()
  }

}
// AgeAvgFunction继承UserDefinedAggregateFunction,需要重写8个方法
class AgeAvgFunction extends UserDefinedAggregateFunction{

  // 指定聚合函数的输入数据类型
  override def inputSchema: StructType = {
  	// age为要聚合的列,LongType为类型
    new StructType().add("age",LongType)
    //也可以写成下面这种形式
    // StructType(StructField("age",LongType)::Nil)
  }

  // 指定缓冲区的数据结构
  override def bufferSchema: StructType = {
  	// 此处可以这样理解:缓冲区会保存两个数据,sum是用来记录年龄总和,count是用来记录总人数
    new StructType().add("sum",LongType).add("count",LongType)
    //    StructType(StructField("sum",LongType)::StructField("count",LongType)::Nil)
  }
  // 指定集合函数输出数据的类型
  override def dataType: DataType = DoubleType
  // 聚合函数是否是幂相等,即相同输入数据是否总能得到相同输出数据
  override def deterministic: Boolean = true
  // 初始化缓冲区的初始值,可根据需要自行设定
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)=0L
    buffer(1)=0L
  }
  // 可理解为单个buffer内部的计算,即一条数据传递到一个buffer内后,它需要把之前的年龄与此条数据年龄相加,然后数量加1
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }
  // 合并多个buffer计算的结果,类似不同分区结果合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // 计算最终结果,总年龄/总人数
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble/buffer.getLong(1)
  }
}

UDTF

UDTF(User-Defined Table-Generating Functions),用户自定义生成函数。它就是输入一行输出多行的自定义算子,可输出多行多列,又被称为 “表生成函数”。(一对多)

object SparkUDTFDemo {
  def main(args: Array[String]): Unit = {
    // 创建SparkSession
    val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("sqlDemo")
      .enableHiveSupport()  // //启用hive
      .getOrCreate
    val sc: SparkContext = spark.sparkContext
    import spark.implicits._
    
    val lines: RDD[String] = sc.textFile("in/udtf.txt")
    // 将数据处理并转换成DataFrame
    val stuDF: DataFrame = lines.map(_.split(","))
      .filter(x => x(1).equals("ls"))
      .map(x => (x(0), x(1), x(2)))
      .toDF("id", "name", "class")
    // 创建或替换临时视图
    stuDF.createOrReplaceTempView("student")
    // 这里需要注意,如果编写的UDTF类有包名,as 后面需要将表名写上
    spark.sql("create temporary function myFunc as 'sql.myUDTF'")
    // 在spark sql 中使用UDTF查询
    val resultDF: DataFrame = spark.sql("select myFunc(class) from student")
    // 查看结果
    resultDF.printSchema()
    resultDF.show()
  }
}

// 继承GenericUDTF类
class myUDTF extends GenericUDTF{
  // 该函数的作用:①输入参数校验,只能传递一个参数 ②指定输出的字段名和字段类型
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    // 只能有一个参数,若多于1个,则抛异常
    if(argOIs.length!=1){
      throw new UDFArgumentException("只能传递一个参数")
    }
    // 用于验证参数的类型
    if(argOIs(0).getCategory!=ObjectInspector.Category.PRIMITIVE){
      throw new UDFArgumentException("参数类型不匹配")
    }
    
    //初始化表结构
    //创建数组列表存储表字段
    val fieldNames = new util.ArrayList[String]()
    val fieldsOIs = new util.ArrayList[ObjectInspector]()
    // 输出字段的名称
    fieldNames.add("hobbies")
    // 这里定义的是输出列字段类型
    fieldsOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    //将表结构两部分聚合在一起
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,fieldsOIs)
  }

  // 用于处理数据,入参数组objects里只有1行数据,即每次调用process方法只处理一行数据
  override def process(objects: Array[AnyRef]): Unit = {
    // 将字符串切分为单个字符的数组
    val strings: Array[String] = objects(0).toString.split(" ")
    for (elem <- strings) {
      val tmp = new Array[String](1)
      tmp(0) = elem
      //forward必须传入字符串数组,即使只有一个元素
      forward(tmp)
    }
  }

  override def close(): Unit = {}
}