二分类
程序员文章站
2022-05-26 19:09:38
...
数据集：泰坦尼克号（Titanic）乘客数据集（train.csv）
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithSGD, SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.tree.loss.SquaredError
import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
/**
 * Predicts Titanic passenger survival with the binary classifiers of Spark MLlib
 * (SVM, logistic regression, decision tree, random forest and gradient-boosted trees)
 * and prints the area under the ROC curve of each model on a held-out test split.
 *
 * Command-line arguments (both optional):
 *   - `args(0)`: path of the Titanic `train.csv`; defaults to `G:\sparkmldata\train.csv`.
 *   - `args(1)`: milliseconds to keep the application alive after evaluation so the
 *     Spark web UI can be inspected; defaults to 0 (stop immediately).
 */
object TitanicClassificationMLTest {
  import org.apache.spark.mllib.linalg.Vector
  import org.apache.spark.mllib.tree.loss.LogLoss

  /** Dataset location used when no path is passed on the command line. */
  private val DefaultInputPath = "G:\\sparkmldata\\train.csv"

  /** Fixed seed for the train/test split so repeated runs compare models on the same data. */
  private val SplitSeed = 42L

  def main(args: Array[String]): Unit = {
    val inputPath = args.headOption.getOrElse(DefaultInputPath)
    // Optional pause (ms) before shutdown; an unparsable value falls back to no pause.
    val uiPauseMs = args.lift(1)
      .flatMap(s => scala.util.Try(s.trim.toLong).toOption)
      .getOrElse(0L)

    // Build the SparkSession (local mode, 4 worker threads).
    val spark = SparkSession.builder()
      .appName("TitanicClassificationMLTest")
      .master("local[4]")
      .getOrCreate()
    try {
      // Brings the $"column" syntax into scope.
      import spark.implicits._
      spark.sparkContext.setLogLevel("WARN")

      // a. Read the dataset: first line is a header, column types are inferred from the data.
      val titanicDF: DataFrame = spark.read
        .option("header", "true").option("inferSchema", "true")
        .csv(inputPath)
      // Show sample rows and the inferred schema.
      titanicDF.show(10, truncate = false)
      titanicDF.printSchema()

      // Mean of the non-null Age values; used below to impute missing ages.
      val avgAge: Double = titanicDF.agg("Age" -> "avg").first().getDouble(0)

      // b. Feature engineering: build one LabeledPoint per passenger with features
      //    [Pclass, Sex (male = 1.0, otherwise 0.0), Age (imputed), SibSp, Parch, Fare].
      //    NOTE(review): features are not standardized; the SGD-based SVM/LR below are
      //    sensitive to feature scale (e.g. Fare vs. the 0/1 Sex flag) - consider StandardScaler.
      val titanicRDD: RDD[LabeledPoint] = titanicDF.select(
        $"Survived", $"Pclass", $"Sex", $"Age", $"SibSp", $"Parch", $"Fare"
      ).rdd.map { row =>
        val label = row.getInt(0).toDouble
        val sexFeature = if ("male".equals(row.getString(2))) 1.0 else 0.0
        val ageFeature = if (row.isNullAt(3)) avgAge else row.getDouble(3)
        val features = Vectors.dense(
          Array(row.getInt(1).toDouble, sexFeature, ageFeature,
            row.getInt(4).toDouble, row.getInt(5).toDouble, row.getDouble(6))
        )
        LabeledPoint(label, features)
      }

      // 80/20 train/test split. Both parts are cached because every model below scans the
      // training data many times (100 SGD iterations, one pass per tree level) and every
      // evaluation scans the test data; without caching each scan re-reads and re-parses the CSV.
      val Array(trainRDD, testRDD) = titanicRDD.randomSplit(Array(0.8, 0.2), seed = SplitSeed)
      trainRDD.cache()
      testRDD.cache()

      // c. Train and evaluate the binary classifiers.

      // c.1 Linear SVM, 100 SGD iterations. clearThreshold() makes predict() return the raw
      //     margin instead of a hard 0/1 label, so the ROC curve is built from real scores.
      val svmModel: SVMModel = SVMWithSGD.train(trainRDD, 100)
      svmModel.clearThreshold()
      val svmAuc = areaUnderROC(testRDD)(features => svmModel.predict(features))
      println(s"使用SVM预测评估ROC: ${svmAuc}")

      // c.2 Logistic regression, 100 SGD iterations; with the threshold cleared predict()
      //     returns the class-1 probability.
      val lrModel: LogisticRegressionModel = LogisticRegressionWithSGD.train(trainRDD, 100)
      lrModel.clearThreshold()
      val lrAuc = areaUnderROC(testRDD)(features => lrModel.predict(features))
      println(s"使用LogisticRegression预测评估ROC: ${lrAuc}")

      // c.3 Decision tree: 2 classes, no categorical features, Gini impurity,
      //     maxDepth = 5, maxBins = 8.
      val dtcModel = DecisionTree.trainClassifier(
        trainRDD, 2, Map[Int, Int](), "gini", 5, 8
      )
      val dtcAuc = areaUnderROC(testRDD)(features => dtcModel.predict(features))
      println(s"使用DecisionTree预测评估ROC: ${dtcAuc}")

      // c.4 Random forest: 2 classes, no categorical features, 10 trees,
      //     "sqrt" feature-subset strategy, Gini impurity, maxDepth = 5, maxBins = 8.
      val rfcModel = RandomForest.trainClassifier(
        trainRDD, 2, Map[Int, Int](), 10, "sqrt", "gini", 5, 8
      )
      val rfcAuc = areaUnderROC(testRDD)(features => rfcModel.predict(features))
      println(s"使用RandomForest预测评估ROC: ${rfcAuc}")

      // c.5 Gradient-boosted trees for binary classification (maxDepth = 5, 2 classes,
      //     maxBins = 8). LogLoss is the classification loss (the one MLlib's default
      //     classification boosting strategy uses); SquaredError is a regression loss.
      val gbtModel = GradientBoostedTrees.train(
        trainRDD,
        BoostingStrategy(
          new Strategy(Algo.Classification, Gini, 5, 2, 8),
          LogLoss
        )
      )
      val gbtAuc = areaUnderROC(testRDD)(features => gbtModel.predict(features))
      println(s"使用GradientBoostedTrees预测评估ROC: ${gbtAuc}")

      // Optionally keep the driver alive so the Spark web UI stays reachable.
      if (uiPauseMs > 0) Thread.sleep(uiPauseMs)
    } finally {
      // Release Spark resources even if reading, training or evaluation fails.
      spark.stop()
    }
  }

  /**
   * Scores every point of `data` with `score` and returns the area under the ROC curve.
   *
   * @param data  labelled evaluation points
   * @param score maps a feature vector to a prediction; a continuous score (margin or
   *              probability) yields a meaningful curve, while a hard 0/1 label
   *              yields a single-point curve
   * @return area under the ROC curve of the (score, label) pairs
   */
  private def areaUnderROC(data: RDD[LabeledPoint])(score: Vector => Double): Double = {
    val scoreAndLabels: RDD[(Double, Double)] = data.map { point =>
      (score(point.features), point.label)
    }
    new BinaryClassificationMetrics(scoreAndLabels).areaUnderROC()
  }
}