欢迎您访问程序员文章站!本站旨在为大家提供、分享程序员计算机编程知识!
您现在的位置是: 首页

二分类

程序员文章站 2022-05-26 19:09:38
...

数据集:泰坦尼克号的数据集

import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithSGD, SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.tree.loss.SquaredError
import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
  * Trains several Spark MLlib binary classifiers on the Titanic data set and
  * prints the area under the ROC curve that each one reaches on a held-out split.
  *
  * Models trained: linear SVM, logistic regression, decision tree, random
  * forest and gradient-boosted trees.
  */
object TitanicClassificationMLTest {

  /** Data set location used when no path is supplied on the command line. */
  private val DefaultDataPath = "G:\\sparkmldata\\train.csv"

  /**
    * Program entry point.
    *
    * @param args optional; when present, `args(0)` is the path of the Titanic
    *             training CSV. Otherwise [[DefaultDataPath]] is read.
    */
  def main(args: Array[String]): Unit = {
    val dataPath = args.headOption.getOrElse(DefaultDataPath)

    // Build the SparkSession.
    val spark = SparkSession.builder()
      .appName("TitanicClassificationMLTest")
      .master("local[4]")
      .getOrCreate()
    // Bring the `$"col"` column syntax into scope.
    import spark.implicits._

    // try/finally guarantees the session is stopped even when a step fails.
    try {
      spark.sparkContext.setLogLevel("WARN")

      // a. Read the Titanic data set (header row present, column types inferred).
      val titanicDF: DataFrame = spark.read
        .option("header", "true").option("inferSchema", "true")
        .csv(dataPath)
      // Print a sample of the data and the inferred schema.
      titanicDF.show(10, truncate = false)
      titanicDF.printSchema()

      // Mean age over all passengers; used below to fill in missing ages.
      val avgAge: Double = titanicDF.select($"Age").agg("Age" -> "avg").first().getDouble(0)

      // b. Feature engineering: turn every row into a LabeledPoint.
      val titanicRDD: RDD[LabeledPoint] = toLabeledPoints(titanicDF, avgAge)

      // Split into training (80%) and test (20%) sets. Both splits are reused by
      // five trainings/evaluations, so they are cached to avoid re-reading and
      // re-parsing the CSV every time.
      val Array(trainRDD, testRDD) = titanicRDD.randomSplit(Array(0.8, 0.2)).map(_.cache())

      // c. Train and evaluate the binary classifiers: SVM, LR, DT, RF and GBT.

      // c.1. Support vector machine (100 SGD iterations).
      val svmModel: SVMModel = SVMWithSGD.train(trainRDD, 100)
      // Without a threshold, predict() returns the raw margin instead of a 0/1
      // label; the ROC curve needs scores to be meaningful.
      svmModel.clearThreshold()
      reportAreaUnderROC("SVM", testRDD)(features => svmModel.predict(features))

      // c.2. Logistic regression (100 SGD iterations).
      val lrModel: LogisticRegressionModel = LogisticRegressionWithSGD.train(trainRDD, 100)
      // Same reason as for the SVM: score with raw values rather than hard labels.
      lrModel.clearThreshold()
      reportAreaUnderROC("LogisticRegression", testRDD)(features => lrModel.predict(features))

      // Settings shared by the tree-based classifiers below.
      val numClasses = 2
      val noCategoricalFeatures = Map[Int, Int]()
      val maxDepth = 5
      val maxBins = 8

      // c.3. Decision tree classifier. The tree models are scored with the
      // values their predict() returns.
      val dtcModel = DecisionTree.trainClassifier(
        trainRDD, numClasses, noCategoricalFeatures, "gini", maxDepth, maxBins
      )
      reportAreaUnderROC("DecisionTree", testRDD)(features => dtcModel.predict(features))

      // c.4. Random forest classifier: 10 trees, "sqrt" feature-subset strategy.
      val rfcModel = RandomForest.trainClassifier(
        trainRDD, numClasses, noCategoricalFeatures, 10, "sqrt", "gini", maxDepth, maxBins
      )
      reportAreaUnderROC("RandomForest", testRDD)(features => rfcModel.predict(features))

      // c.5. Gradient-boosted trees classifier.
      // NOTE(review): the MLlib guide pairs classification with log loss;
      // confirm that SquaredError is intended here.
      val gbtModel = GradientBoostedTrees.train(
        trainRDD,
        BoostingStrategy(
          new Strategy(Algo.Classification, Gini, maxDepth, numClasses, maxBins),
          SquaredError
        )
      )
      reportAreaUnderROC("GradientBoostedTrees", testRDD)(features => gbtModel.predict(features))

      // Keep the application alive so the Spark web UI can still be inspected.
      Thread.sleep(1000000)
    } finally {
      // Release Spark resources.
      spark.stop()
    }
  }

  /**
    * Converts the raw Titanic rows into labelled feature vectors.
    *
    * The label is `Survived`; the features are, in order, `Pclass`, `Sex`
    * (male = 1.0, anything else = 0.0), `Age`, `SibSp`, `Parch` and `Fare`.
    *
    * @param titanicDF data frame containing the columns listed above
    * @param avgAge    value substituted for rows whose `Age` is null
    * @return one [[LabeledPoint]] per input row
    */
  private def toLabeledPoints(titanicDF: DataFrame, avgAge: Double): RDD[LabeledPoint] =
    titanicDF
      .select("Survived", "Pclass", "Sex", "Age", "SibSp", "Parch", "Fare")
      .rdd
      .map { row =>
        val label = row.getInt(0).toDouble

        // Encode Sex numerically: male -> 1.0, female (or anything else) -> 0.0.
        val sexFeature = if (row.getString(2) == "male") 1.0 else 0.0

        // Some passengers have no recorded age; use the mean age for them.
        val ageFeature = if (row.isNullAt(3)) avgAge else row.getDouble(3)

        val features = Vectors.dense(
          Array(
            row.getInt(1).toDouble, sexFeature, ageFeature,
            row.getInt(4).toDouble, row.getInt(5).toDouble, row.getDouble(6)
          )
        )
        LabeledPoint(label, features)
      }

  /**
    * Scores every test point with `score` and prints the area under the ROC curve.
    *
    * @param modelName name of the model, embedded in the printed message
    * @param testData  held-out labelled points
    * @param score     maps a feature vector to the model's prediction or score
    */
  private def reportAreaUnderROC(modelName: String, testData: RDD[LabeledPoint])(
      score: org.apache.spark.mllib.linalg.Vector => Double
  ): Unit = {
    val scoreAndLabels: RDD[(Double, Double)] =
      testData.map(point => (score(point.features), point.label))
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    println(s"使用${modelName}预测评估ROC: ${metrics.areaUnderROC()}")
  }

}