欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Spark 从 0 到 1 学习(7) —— Spark SQL

程序员文章站 2022-04-17 07:53:18
...

1. Shark

Shark 是基于 Spark 计算框架之上且兼容 Hive 语法的 SQL 执行引擎。由于底层的计算采用了 Spark,性能比 MapReduce 的 Hive 普遍快 2 倍以上。当数据全部 load 到内存的话,将快 10 倍以上。因此,Shark 可以作为交互式查询应用服务来使用。除了基于 Spark的特性外,Shark 是完全兼容 Hive 的语法,表结构以及 UDF 函数等。已有的 HiveSql 可以直接进行迁移到 Shark上。Shark 底层依赖于 Hive 的解析器、查询优化器。但正是由于 Shark 的整体设计架构对 Hive 的依赖性太强,难以支持其长远发展,比如不能和 Spark 的其他组件进行很好的集成,无法满足 Spark 的一栈式解决大数据处理的需求。

2. SparkSQL

2.1 SparkSQL 介绍

Hive 是 Shark 的前身, Shark 是 SparkSQL 的前身。SparkSQL 产生的根本原因是其完全脱离了 Hive 的限制。

  • SparkSQL 支持查询原生的 RDD。RDD 是 Spark 平台的核心概念,是 Spark 能够高效的处理大数据的各种场景的基础。
  • 能够在代码中写 SQL 语句。支持简单的 SQL 语法查询。
  • 可以在代码中写 Hive 语句访问 Hive 中的数据,并将结果取回作为 RDD 使用。

2.2 Spark on Hive 和 Hive on Spark

  • Spark on Hive:Hive 只作为存储角色,Spark 负责 sql 解析、优化、执行。
  • Hive on Spark:Hive 既作为存储,又负责 sql 的解析、优化, Spark 负责执行。

2.3 Dataset

Spark 从 0 到 1 学习(7) —— Spark SQL

Dataset 也是一个分布式数据容器。与 RDD 类似,然而 Dataset 更像传统数据库的二维表格。除了数据以外,还掌握数据的结构信息,即 schema。同时,与 Hive 类似,Dataset 也支持嵌套数据类型 (struct、array 和 map)。从 API 易用性的角度上看, Dataset API 提供的是一套高层的关系操作,比函数式的 RDD API 更加友好,门槛更低。

Dataset 的底层封装的是 RDD,只不过 RDD 的泛型是 Row 类型。

2.4 SparkSQL 的数据源

SparkSQL 的数据源可以是 JSON 类型的字符串、jdbc、Hive、HDFS 等。

2.5 SparkSQL 底层架构

首先拿到 sql 后解析一批未被解决的逻辑计划,再经过分析得到分析后的逻辑计划,经过一批优化规则转换成一批最佳优化的逻辑计划,再经过 SparkPlanner 的策略转化成一批物理计划,随后经过消费模型转换成一个个的 Spark 任务执行。

2.6 谓词下推(Predicate pushdown)

Spark 从 0 到 1 学习(7) —— Spark SQL

谓词下推是优化关系 SQL 查询的一项基本技术,将外层查询块的 WHERE 子句中的谓词移入所包含的较低层查询块(例如视图),从而能够提早进行数据过滤以及有可能更好地利用索引。

基本策略是,始终将过滤表达式尽可能移至靠近数据源的位置。

如图-左显示:关系表先进行关联查询再过滤出结果,这种查询效果比较慢。

如图-右显示:关系表先通过条件过滤出想要的结果,然后关联出最终的结果。这就是使用了谓词下推优化技术。

3. 创建 Dataset 的几种方式

3.1 读取 json 格式的文件创建

注意:

  • json 文件中的 json 数据不能嵌套 json 格式的数据。
  • Dataset 是一个个 Row 类型的 RDD,通过ds.rdd()/ds.javaRdd()转换成RDD。
  • 可以两种方式读取 json 格式的文件。
  • ds.show()默认显示前 20 行数据。
  • Dataset 原生 API 可以操作 Dataset (不方便)。
  • 注册成临时表时,表中的列默认按 ascii顺序显示列。

测试数据:

{    "id": 1,    "name": "张三",    "age": 23  }
{    "id": 2,    "name": "李四",    "age": 26  }
{    "id": 3,    "name": "王五",    "age": 20  }
{    "id": 4,    "name": "王璐",    "age": 27  }

java 方式:

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

public class SparkJson {

    public static void main(String[] args) {
        SparkConf conf = new SparkConf();
        conf.setMaster("local").setAppName("SparkJson");
        SparkContext sc = new SparkContext(conf);

        // 创建 SQLContext
        SQLContext sqlContext = new SQLContext(sc);

        // 读取 json文件
        String path = "/stu.txt";
        //Dataset<Row> ds = sqlContext.read().format("json").load(path);
        Dataset<Row> ds = sqlContext.read().json(path);
        // 显示schema信息
        ds.printSchema();
        // 显示数据
        ds.show();
        // Dataset 自带的 API
        ds.select(df.col("name"),df.col("age").plus(10).alias("addage")).show();
        ds.select(df.col("name"),df.col("age")).where(df.col("age").gt(25)).show();

        ds.groupBy(df.col("age")).count().show();

        // 注册临时表
        ds.createTempView("student");

        String sql = "select age, count(1) from student group by age";
        sqlContext.sql(sql).show();

    }
}

scala:

import org.apache.spark.sql.SparkSession

object SparkJson {

  def main(args: Array[String]): Unit = {
    val builder = SparkSession.builder().appName("SparkJson").master("local").getOrCreate()
    val sqlContext = builder.sqlContext
    val path = "/stu.txt"
    val ds = sqlContext.read.format("json").load(path)
//    val ds = sqlContext.read.json(path)

    // 输出 schame
    ds.printSchema()
    // 输出数据
    ds.show()

    // 输出需要的字段
    ds.select(ds.col("name"),ds.col("age")).show()

    // 条件查询
    ds.select(ds.col("name"),ds.col("age")).where(ds.col("age").>(25)).show()

    //  分组聚合
//    ds.groupBy(ds.col("age")).count().show()

    // 创建视图
    ds.createTempView("student")
    val sql = "select age, count(1) from student group by age"
    sqlContext.sql(sql).show()
  }

}

3.2 非 json 格式的 RDD 转 Dataset

3.2.1 通过反射的方式将 非 json 格式的 RDD 转换成 Dataset (不建议使用)

使用反射的方式转换的时候需要注意:

  • 自定义类需要序列化
  • 自定义类的访问级别是 public
  • RDD 转成 Dataset 后会根据自动按 assic码排序
  • 将 Dataset 转换成 RDD 时获取字段有两种方式:
    1. ds.getInt(0):通过下标获取(不推荐使用)。
    2. ds.getAs("列名"):通过列明获取 (推荐使用)。

测试数据:

1,张三,23
2,李四,26
3,王五,26
4,王璐,27

java:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;

import java.io.Serializable;

public class SparkText {

    public static void main(String[] args) throws Exception{
        SparkSession session = SparkSession.builder().appName("SparkText").master("local").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(session.sparkContext());
        SQLContext sqlContext = session.sqlContext();
        String path = "/student.txt";
        JavaRDD<Student> rdd = jsc.textFile(path).map(line -> {
            Student stu = new Student();
            String[] ss = line.split(",");
            stu.id = Integer.parseInt(ss[0]);
            stu.name = ss[1];
            stu.age = Integer.parseInt(ss[2]);
            return stu;
        });
        /**
        * 传入进去Person.class的时候,sqlContext是通过反射的方式创建Dataset
        * 在底层通过反射的方式获得Person的所有field,结合RDD本身,就生成了Dataset
        */
        Dataset<Row> ds = sqlContext.createDataFrame(rdd, Student.class);
    }

    public static class Student implements Serializable {
        public int id;
        public String name;
        public int age;

        public int getId() {
            return id;
        }

        public void setId(int id) {
            this.id = id;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        public int getAge() {
            return age;
        }

        public void setAge(int age) {
            this.age = age;
        }
    }
}

scala:

import org.apache.spark.sql.SparkSession


object SparkText {

  case class Student(id:Int,name:String,age:Int)

  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder().master("local").appName("SparkText").getOrCreate()
    val sc = session.sparkContext
    val sqlContext = session.sqlContext
    val path = "/student.txt"
    // 将RDD隐式转换成DataSet
    import sqlContext.implicits._
    val rdd = sc.textFile(path).map{line => {
      val ss = line.split(",")
      val id = ss(0).toInt
      val name = ss(1)
      val age = ss(2).toInt
      val stu = Student(id, name, age)
      stu
    }}
    val ds = rdd.toDF()

  }
}

3.2.2 动态创建 Schema 将非 json 格式的 RDD 转换成 Dataset

java:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

public class SparkRow {

    public static void main(String[] args) throws Exception{
        SparkSession session = SparkSession.builder().master("local").appName("SparkRow").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(session.sparkContext());
        SQLContext sqlContext = session.sqlContext();
        String path = "/student.txt";
        JavaRDD<Row> rdd = jsc.textFile(path).map(line -> {
            String[] ss = line.split(",");
            int id = Integer.parseInt(ss[0]);
            String name = ss[1];
            int age = Integer.parseInt(ss[2]);
            return RowFactory.create(id, name, age);
        });

        /**
         *  创建 schema
         */
        List<StructField> fields = new ArrayList<>(3);
        fields.add(DataTypes.createStructField("id",DataTypes.IntegerType,true));
        fields.add(DataTypes.createStructField("name",DataTypes.StringType,true));
        fields.add(DataTypes.createStructField("age",DataTypes.IntegerType,true));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row>  df = sqlContext.createDataFrame(rdd, schema);
        // 显示schema信息
        df.printSchema();
        // 显示数据
        df.show();

    }
}

scala:

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{RowFactory, SparkSession}

object SparkRow {

  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder().master("local").appName("SparkRow").getOrCreate()
    val sc = session.sparkContext
    val sqlContext = session.sqlContext

    val path = "/student.txt"
    val rdd = sc.textFile(path).map(line => {
      val ss = line.split(",")
      val id = Integer.valueOf(ss(0))
      val name = ss(1)
      val age = Integer.valueOf(ss(2))
      val row = RowFactory.create(id, name, age)
      row
    })
    val fields= List(
      StructField("id",IntegerType,true),
      StructField("name",StringType,true),
      StructField("age",IntegerType,true)
    )

    val schema = StructType(fields)
    val ds = sqlContext.createDataFrame(rdd,schema)
    // 输出 schame
    ds.printSchema()
    // 输出数据
    ds.show()
    sc.stop()
  }
}

3.3 读取 JDBC 中的数据创建 Dataset(MySql 为例)

java:

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;

import java.util.Properties;

public class SparkJDBC {

    public static void main(String[] args) throws Exception{
        SparkSession session = SparkSession.builder().master("local").appName("SparkJDBC").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(session.sparkContext());
        SQLContext sqlContext = session.sqlContext();
        Properties options = new Properties();
        String url = "jdbc:mysql://127.0.0.1:3306/student?useUnicode=true&characterEncoding=utf8&allowMultiQueries=true";
        String driver = "com.mysql.jdbc.Driver";
        String user = "root";
        String password = "123456";
        String dbtable = "student";
        options.put("driver", driver);
        options.put("user",user);
        options.put("password",password);
        Dataset<Row> ds = sqlContext.read().jdbc(url, dbtable, options);
        ds.createTempView("student");
        String sqlText = "select * from student limit 30";
        sqlContext.sql(sqlText).show();
        jsc.stop();
    }
}

scala:

import java.util.Properties

import org.apache.spark.sql.SparkSession

object SparkJDBC {

  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder().master("local").appName("SparkJDBC").getOrCreate()
    val sqlContext = session.sqlContext
    val options = new Properties()
    val url = "jdbc:mysql://127.0.0.1:3306/student?useUnicode=true&characterEncoding=utf8&allowMultiQueries=true"
    val driver = "com.mysql.jdbc.Driver"
    val user = "root"
    val password = "123456"
    val dbtable = "student"
    options.put("driver", driver)
    options.put("user", user)
    options.put("password", password)
    val ds = sqlContext.read.jdbc(url,dbtable,options)
    ds.createTempView("student")
    val sqlText = "select * from student limit 30"
    sqlContext.sql(sqlText).show()
    session.stop()
  }

}

3.4 读取 Hive 中的数据创建 Dataset

  • HiveContextSQLContext的子类,连接 Hive 建议使用 HiveContext

  • 由于本地没有 Hive 环境,运行时需要提交到集群运行。

    ./spark-submit
    --master spark://node1:7077,node2:7077
    --executor-cores 1
    --executor-memory 2G
    --total-executor-cores 1
    --class com.spark.sql.hive.SparkHive
    /root/spark/job/SparkHiveTest.jar
    

java:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SparkHive {

    public static void main(String[] args) {
        SparkSession session = SparkSession.builder().appName("SparkHive").enableHiveSupport().getOrCreate();
        // 指定 数据库
        session.sql("USE spark");
        // 删除表
        session.sql("DROP TABLE IF EXISTS student");
        // 创建表
        session.sql("CREATE TABLE IF NOT EXISTS student (id INT,name STRING,age INT) " +
                "row format delimited fields terminated by '\t' ");
        // 加载本地数据到hive
        String dataPath = "/student.txt";
        session.sql("load data local inpath "+dataPath+" into table student");

        // 查询hive数据
        Dataset<Row> ds = session.sql("select * from student");
        ds.show();
        session.stop();
    }

}

scala:

import org.apache.spark.sql.SparkSession

object SparkHive {

  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder().appName("SparkHive").enableHiveSupport().getOrCreate()
    // 指定 数据库
    session.sql("USE spark")
    // 删除表
    session.sql("DROP TABLE IF EXISTS student")
    // 创建表
    session.sql("CREATE TABLE IF NOT EXISTS student (id INT,name STRING,age INT) " + "row format delimited fields terminated by '\t' ")
    // 加载本地数据到hive
    val dataPath = "/student.txt"
    session.sql("load data local inpath " + dataPath + " into table student")

    // 查询hive数据
    val ds = session.sql("select * from student")
    ds.show()
    session.stop()
  }

}

4. Spark On Hive 的配置

  1. 修改spark的 hive-site.xml 文件

    在 spark/conf 中创建 hive-site.xml,配置 hive 的 metastore 路径

    <configuration>
       <property>
            <name>hive.metastore.uris</name>
            <value>thrift://node1:9083</value>
       </property>
    </configuration>
    
  2. 启动 Hive 的 metastore 服务

    hive --service metastore
    
  3. 启动 Zookeeper 集群,启动 HDFS 集群

  4. 验证是否启动成功

    这里通过 SparkShell 操作 Hive 来验证配置是否成功。

    ./spark-shell
    --master spark://node:7077,node2:7077
    --executor-cores 1
    --executor-memory 1g
    --total-executor-cores 1
    import org.apache.spark.sql.SparkSession
    val session = SparkSession.builder().appName("SparkHive").enableHiveSupport().getOrCreate()
    session.sql("show databases").show
    

    注意:

    如果使用 Spark on Hive 查询数据时,出现错误:

    Cased by: java.net.UnknownHostExeception: XXX
    

    找不到 HDFS 集群路径,要在客户端集群 conf/spark-env.sh 中设置 HDFS 的路径:

    export HADOOP_CONF_DIR=$HADOOP_HOME/etc/hadoop
    

5. 存储 Dataset

  1. 将 Dataset 存储为 parquet 文件。

    ds.write().mode(SaveMode.Overwrite).format("parquet").save("/spark/job/result/student");
    
    ds.write.mode(saveMode = SaveMode.Overwrite).format("parquet").save("/spark/job/result/student")
    
  2. 将 Dataset 存储到 JDBC 数据库。

    ds.write().mode(SaveMode.Append).jdbc(url,"result",options);
    
    data.write.mode(saveMode = SaveMode.Append).jdbc(url,table = "result",options)
    
  3. 将 Dataset 存储到 Hive表。

6. 自定义函数 UDF 和 UDAF

6.1 UDF:用户自定义函数

自定义函数可以实现 UDFX 接口。

6.1.1 java 实现

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

public class SparkSqlUDF {
    public static void main(String[] args) throws Exception{
        SparkSession session = SparkSession.builder().master("local").appName("SparkSqlUDF").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(session.sparkContext());
        SQLContext sqlContext = session.sqlContext();
        String path = "/student.txt";
        JavaRDD<Row> rdd = jsc.textFile(path).map(line -> {
            String[] ss = line.split(",");
            int id = Integer.parseInt(ss[0]);
            String name = ss[1];
            int age = Integer.parseInt(ss[2]);
            return RowFactory.create(id, name, age);
        });


        /**
         *  创建 schema
         */
        List<StructField> fields = new ArrayList<>(3);
        fields.add(DataTypes.createStructField("id",DataTypes.IntegerType,true));
        fields.add(DataTypes.createStructField("name",DataTypes.StringType,true));
        fields.add(DataTypes.createStructField("age",DataTypes.IntegerType,true));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> df = sqlContext.createDataFrame(rdd, schema);
        df.createTempView("student");
        sqlContext.udf().register("StrLen", new UDF1<String, Integer>() {
            @Override
            public Integer call(String s) throws Exception {
                return s.length();
            }
        }, DataTypes.IntegerType);

        sqlContext.sql("select name, StrLen(name) as length from student").show();

        jsc.stop();
    }
}

6.1.2 scala 实现

import org.apache.spark.sql.{RowFactory, SaveMode, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object SparkSqlUDF {

  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder().master("local").appName("SparkRow").getOrCreate()
    val sc = session.sparkContext
    val sqlContext = session.sqlContext

    val path = "/student.txt"
    val rdd = sc.textFile(path).map(line => {
      val ss = line.split(",")
      val id = Integer.valueOf(ss(0))
      val name = ss(1)
      val age = Integer.valueOf(ss(2))
      val row = RowFactory.create(id, name, age)
      row
    })
    val fields= List(
      StructField("id",IntegerType,true),
      StructField("name",StringType,true),
      StructField("age",IntegerType,true)
    )

    val schema = StructType(fields)
    val ds = sqlContext.createDataFrame(rdd,schema)

    // 创建视图
    ds.createTempView("student")
	// 自定义函数
    sqlContext.udf.register("StrLen",(s:String,i:Int) =>(s.length+i))
    val sql = "select name, StrLen(name,1) as length from student"
    sqlContext.sql(sql).show()
    sc.stop()
  }

}

6.2 UDAF:用户自定义聚合函数

实现 UDAF 函数需要继承 UserDefinedAggregateFunction

6.2.1 java实现

自定义UDAF

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

import java.util.Arrays;

public class UDAFUserDefinedAgg extends UserDefinedAggregateFunction {

    /**
     * 指定输入字段的字段及类型
     */
    @Override
    public StructType inputSchema() {
        StructType name = DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("name", DataTypes.StringType, true)));
        return name;
    }

    /**
     * 在进行聚合操作的时候所要处理的数据的结果的类型
     */
    @Override
    public StructType bufferSchema() {
        StructType bf = DataTypes.createStructType(
                Arrays.asList(DataTypes.createStructField("bf", DataTypes.IntegerType,
                        true)));
        return bf;
    }

    /**
     * 指定UDAF函数计算后返回的结果类型
     */
    @Override
    public DataType dataType() {
        return DataTypes.IntegerType;
    }

    @Override
    public boolean deterministic() {
        // 设置为 true
        return true;
    }

    /**
     * 初始化一个内部的自己定义的值,在Aggregate之前每组数据的初始化结果
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0,0);
    }

    /**
     * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
     * buffer.getInt(0)获取的是上一次聚合后的值
     * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
     * 大聚和发生在reduce端.
     * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0,buffer.getInt(0)+1);
    }

    /**
     * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
     * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
     * buffer1.getInt(0) : 大聚和的时候 上一次聚合后的值
     * buffer2.getInt(0) : 这次计算传入进来的update的结果
     * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0,buffer1.getInt(0)+buffer2.getInt(0));
    }

    /**
     * 最后返回一个和DataType的类型要一致的类型,返回UDAF最后的计算结果
     */
    @Override
    public Object evaluate(Row buffer) {
        return buffer.getInt(0);
    }
}

spark 程序:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

public class SparkUDAF {

    public static void main(String[] args) throws Exception{
        SparkSession session = SparkSession.builder().master("local").appName("SparkUDAF").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(session.sparkContext());
        SQLContext sqlContext = session.sqlContext();
        String path = "/student.txt";
        JavaRDD<Row> rdd = jsc.textFile(path).map(line -> {
            String[] ss = line.split(",");
            int id = Integer.parseInt(ss[0]);
            String name = ss[1];
            int age = Integer.parseInt(ss[2]);
            return RowFactory.create(id, name, age);
        });


        /**
         *  创建 schema
         */
        List<StructField> fields = new ArrayList<>(3);
        fields.add(DataTypes.createStructField("id",DataTypes.IntegerType,true));
        fields.add(DataTypes.createStructField("name",DataTypes.StringType,true));
        fields.add(DataTypes.createStructField("age",DataTypes.IntegerType,true));
        StructType schema = DataTypes.createStructType(fields);
        Dataset<Row> df = sqlContext.createDataFrame(rdd, schema);
        UDAFUserDefinedAgg stringCount = new UDAFUserDefinedAgg();
        sqlContext.udf().register("StringCount",stringCount);
         // 注册临时表
        df.createTempView("student");

        String sql = "select name, StringCount(name) as num from student group by name";
        sqlContext.sql(sql).show();

    }
}

6.2.2 scala 实现

自定义UDAF:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, IntegerType, StringType, StructType}

class UDAFUserDefinedAgg extends UserDefinedAggregateFunction{
  /**
    * 输入数据的类型
    * @return
    */
  override def inputSchema: StructType = {
    DataTypes.createStructType(Array(DataTypes.createStructField("input",StringType,true)))
  }

  /**
    * 聚合操作时,所处理的数据的类型
    * @return
    */
  override def bufferSchema: StructType = {
    DataTypes.createStructType(Array(DataTypes.createStructField("bf",IntegerType,true)))
  }

  /**
    * 最终函数返回值的类型
    * @return
    */
  override def dataType: DataType = {
    DataTypes.IntegerType
  }

  override def deterministic: Boolean = true

  /**
    * 为每个分组的数据执行初始化值
    * @param buffer
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0;
  }

  /**
    * 最后merger的时候,在各个节点上的聚合值,要进行merge,也就是合并
    * @param buffer
    * @param input
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  /**
    * 最后merger的时候,在各个节点上的聚合值,要进行merge,也就是合并
    * @param buffer1
    * @param buffer2
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0,buffer1.getAs[Int](0) + buffer2.getAs[Int](0))
  }

  /**
    * 最后返回一个最终的聚合值,要和dataType的类型一一对应
    * @param buffer
    * @return
    */
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }
}

spark 程序:

import org.apache.spark.sql.{RowFactory, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object SparkUDAF {

  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder().master("local").appName("SparkRow").getOrCreate()
    val sc = session.sparkContext
    val sqlContext = session.sqlContext

    val path = "/student.txt"
    val rdd = sc.textFile(path).map(line => {
      val ss = line.split(",")
      val id = Integer.valueOf(ss(0))
      val name = ss(1)
      val age = Integer.valueOf(ss(2))
      val row = RowFactory.create(id, name, age)
      row
    })
    val fields= List(
      StructField("id",IntegerType,true),
      StructField("name",StringType,true),
      StructField("age",IntegerType,true)
    )

    val schema = StructType(fields)
    val ds = sqlContext.createDataFrame(rdd,schema)

    sqlContext.udf.register("StringCount", new UDAFUserDefinedAgg)

    // 创建视图
    ds.createTempView("student")
    val sql = "select name, StringCount(name) from student group by name"
    sqlContext.sql(sql).show()
    sc.stop()
  }
}
相关标签: 大数据 spark