Spark SQL 中的UDF、UDAF、UDTF
程序员文章站
2022-04-29 08:07:37
...
UDF
UDF(User-defined functions)用户自定义函数,简单说就是输入一行输出一行的自定义算子。(一对一)
数据文件:hobbies.txt,第一列为姓名,其他为兴趣爱好
alice,jogging&Coding&cooking
lina,traveldance&cooking
自定义UDF,实现的是计算每个人的兴趣爱好个数
// 样例类
case class Hobbies(name:String,hobbies:String)
object UDFDemo {
def main(args: Array[String]): Unit = {
// 获取SparkSession对象
val spark: SparkSession = SparkSession.builder()
.appName("udfDemo")
.master("local[*]")
.getOrCreate()
// 获取SparkContext对象
val sc:SparkContext = spark.sparkContext
import spark.implicits._
// 读取文件
val rdd1: RDD[String] = sc.textFile("in/hobbies.txt")
// 将姓名与爱好以逗号分隔,创建成样例类后转成DataFrame
val df: DataFrame = rdd1.map(_.split(","))
.map(x => Hobbies(x(0), x(1)))
.toDF()
// 注册临时表
df.registerTempTable("hobbies")
// 注册udf,名字为hoby_num,功能用匿名函数代替
spark.udf.register("hoby_num",
(x:String)=>x.split("&").length)
// 查询
val frame: DataFrame = spark.sql("select name,hobbies,hoby_num(hobbies) from hobbies")
frame.show()
}
}
UDAF
UDAF(User Defined Aggregate Function),即用户定义的聚合函数,聚合函数和普通函数的区别是:普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值想办法聚合一下。(多对一)
数据文件:user.json
{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}
自定UDAF,实现计算平均年龄
object UDAFDemo {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder()
.master("local[*]")
.appName("udafDemo")
.getOrCreate()
val df: DataFrame = spark.read.json("in/user.json")
spark.udf.register("ageAvg",new MyAgeAvgFunction)
// 创建临时视图
df.createTempView("userInfo")
// 查询
val frame: DataFrame = spark.sql("select sex,ageAvg(age) from userInfo group by sex")
frame.printSchema()
frame.show()
}
}
// AgeAvgFunction继承UserDefinedAggregateFunction,需要重写8个方法
class AgeAvgFunction extends UserDefinedAggregateFunction{
// 指定聚合函数的输入数据类型
override def inputSchema: StructType = {
// age为要聚合的列,LongType为类型
new StructType().add("age",LongType)
//也可以写成下面这种形式
// StructType(StructField("age",LongType)::Nil)
}
// 指定缓冲区的数据结构
override def bufferSchema: StructType = {
// 此处可以这样理解:缓冲区会保存两个数据,sum是用来记录年龄总和,count是用来记录总人数
new StructType().add("sum",LongType).add("count",LongType)
// StructType(StructField("sum",LongType)::StructField("count",LongType)::Nil)
}
// 指定集合函数输出数据的类型
override def dataType: DataType = DoubleType
// 聚合函数是否是幂相等,即相同输入数据是否总能得到相同输出数据
override def deterministic: Boolean = true
// 初始化缓冲区的初始值,可根据需要自行设定
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0)=0L
buffer(1)=0L
}
// 可理解为单个buffer内部的计算,即一条数据传递到一个buffer内后,它需要把之前的年龄与此条数据年龄相加,然后数量加1
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getLong(0) + input.getLong(0)
buffer(1) = buffer.getLong(1) + 1
}
// 合并多个buffer计算的结果,类似不同分区结果合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
// 计算最终结果,总年龄/总人数
override def evaluate(buffer: Row): Any = {
buffer.getLong(0).toDouble/buffer.getLong(1)
}
}
UDTF
UDTF(User-Defined Table-Generating Functions),用户自定义生成函数。它就是输入一行输出多行的自定义算子,可输出多行多列,又被称为 “表生成函数”。(一对多)
object SparkUDTFDemo {
def main(args: Array[String]): Unit = {
// 创建SparkSession
val spark: SparkSession = SparkSession.builder()
.master("local[*]")
.appName("sqlDemo")
.enableHiveSupport() // //启用hive
.getOrCreate
val sc: SparkContext = spark.sparkContext
import spark.implicits._
val lines: RDD[String] = sc.textFile("in/udtf.txt")
// 将数据处理并转换成DataFrame
val stuDF: DataFrame = lines.map(_.split(","))
.filter(x => x(1).equals("ls"))
.map(x => (x(0), x(1), x(2)))
.toDF("id", "name", "class")
// 创建或替换临时视图
stuDF.createOrReplaceTempView("student")
// 这里需要注意,如果编写的UDTF类有包名,as 后面需要将表名写上
spark.sql("create temporary function myFunc as 'sql.myUDTF'")
// 在spark sql 中使用UDTF查询
val resultDF: DataFrame = spark.sql("select myFunc(class) from student")
// 查看结果
resultDF.printSchema()
resultDF.show()
}
}
// 继承GenericUDTF类
class myUDTF extends GenericUDTF{
// 该函数的作用:①输入参数校验,只能传递一个参数 ②指定输出的字段名和字段类型
override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
// 只能有一个参数,若多于1个,则抛异常
if(argOIs.length!=1){
throw new UDFArgumentException("只能传递一个参数")
}
// 用于验证参数的类型
if(argOIs(0).getCategory!=ObjectInspector.Category.PRIMITIVE){
throw new UDFArgumentException("参数类型不匹配")
}
//初始化表结构
//创建数组列表存储表字段
val fieldNames = new util.ArrayList[String]()
val fieldsOIs = new util.ArrayList[ObjectInspector]()
// 输出字段的名称
fieldNames.add("hobbies")
// 这里定义的是输出列字段类型
fieldsOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
//将表结构两部分聚合在一起
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,fieldsOIs)
}
// 用于处理数据,入参数组objects里只有1行数据,即每次调用process方法只处理一行数据
override def process(objects: Array[AnyRef]): Unit = {
// 将字符串切分为单个字符的数组
val strings: Array[String] = objects(0).toString.split(" ")
for (elem <- strings) {
val tmp = new Array[String](1)
tmp(0) = elem
//forward必须传入字符串数组,即使只有一个元素
forward(tmp)
}
}
override def close(): Unit = {}
}