Spark ml pipline交叉验证之朴素贝叶斯
程序员文章站
2022-05-08 16:47:33
...
Spark ml pipline交叉验证之朴素贝叶斯
1.1 模型训练
1.1.1 输入参数
{
"modelName ": "朴素贝叶斯_运动状态预测 ",
"numFolds ": "3 ",
"labelColumn ": "Class ",
"smoothings ": [
0.01,
0.1,
1
]
}
1.1.2 训练代码
import com.cetc.common.conf.MachineLearnModel
import com.cetc.miner.compute.utils.{ModelUtils, Utils}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{LogisticRegressionModel, NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StandardScaler, VectorAssembler}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}
import scala.collection.JavaConverters._
class NBBestTrain {
val logger: org.apache.log4j.Logger = org.apache.log4j.Logger.getLogger(classOf[LRBestTrain])
/**
* 朴素贝叶斯 分类模型训练
* @param df
* @param id
* @param name
* @param conf
* @param sparkSession
* @return
*/
def execute(df: DataFrame, id: String, name: String, conf: String, sparkSession: SparkSession): java.util.List[MachineLearnModel] = {
df.cache()
logger.info("训练集个数========="+ df.count())
val params = Utils.conf2Class(conf)
//ML的VectorAssembler是一个transformer,要求数据类型不能是string,将多列数据转化为单列的向量列,比如把age、income等等字段列合并成一个 userFea 向量列,方便后续训练
val assembler = new VectorAssembler().setInputCols(df.drop(params.getLabelColumn).columns).setOutputCol("features")
//标准化(归一化)
val standardScaler = new StandardScaler()
.setInputCol(assembler.getOutputCol)
.setOutputCol("scaledFeatures")
.setWithStd(true)//是否将数据缩放到单位标准差。
.setWithMean(false)//是否在缩放前使用平均值对数据进行居中。
//创建线性回归模型
val lr = new NaiveBayes()
.setFeaturesCol(assembler.getOutputCol) // 特征输入
.setLabelCol(params.getLabelColumn) // 要预测的值
//创建机器学习工作流
val pipeline = new Pipeline().setStages(Array(assembler, standardScaler, lr))
//创建多分类评估器,用于基于训练集的多次训练后的模型选择
val classificationEvaluator = new MulticlassClassificationEvaluator()
.setLabelCol(params.getLabelColumn)//真实值
.setPredictionCol("prediction")//模型预测的值
.setMetricName("accuracy")//正确率
//获取最大迭代次数和正则参数,一共可以训练出(smoothings)个模型
import scala.collection.JavaConversions.asScalaBuffer
val paramMap = new ParamGridBuilder()
.addGrid(lr.getParam("smoothing"), asScalaBuffer(params.getSmoothings))
.build
//创建交叉验证器,他会把训练集分成NumFolds份,然后在其中(NumFolds-1)份里进行训练
//在其中一份里进行测试,针对上面的每一组参数都会训练出NumFolds个模型,最后选择一个
// 最优的模型
val crossValidator = new CrossValidator()
.setEstimator(pipeline)
.setEstimatorParamMaps(paramMap)//设置模型需要的超参数组合
.setNumFolds(params.getNumFolds)//把训练集分成多少份数
.setEvaluator(classificationEvaluator)//设置评估器,用户评估测试结果数据
//模型训练
val model = crossValidator.fit(df)
// 最佳模型
val bestModel = model.bestModel.asInstanceOf[PipelineModel]
val nbModel = bestModel.stages(2).asInstanceOf[NaiveBayesModel]
println("模型类型========", nbModel.getClass)
//将模型封装成对象
val modelObject: MachineLearnModel = ModelUtils.saveModel(nbModel, params.getModelName, 2, conf, 0, 0.0)
//保存模型到数据库
ModelUtils.model2mysql(modelObject)
return List(modelObject).asJava
}
}
1.2 模型评估
1.2.1 输入参数
{"labelColumn":""}
1.2.2 评估代码
import java.util
import com.cetc.common.conf.MachineLearnModel
import com.cetc.miner.compute.utils.{ModelUtils, Utils}
import org.apache.spark.ml.classification.NaiveBayesModel
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.{DataFrame, SparkSession}
class NBAssess {
val logger: org.apache.log4j.Logger = org.apache.log4j.Logger.getLogger(classOf[NBAssess])
/**
* 逻辑回归 分类模型评估
* @param df
* @param model
* @param id
* @param name
* @param conf
* @param sparkSession
* @return
*/
def execute(df: DataFrame, model: MachineLearnModel, id: String, name: String, conf: String, sparkSession: SparkSession): java.util.List[Double] = {
logger.info("测试集个数========="+ df.count())
val params = Utils.conf2Class(conf)
val userProfile = Utils.trans2SupervisedLearning(df, params.getLabelColumn)
val nbModel = ModelUtils.loadModel[NaiveBayesModel](model)
//评估器
val classificationEvaluator = new MulticlassClassificationEvaluator()
.setLabelCol(params.getLabelColumn)//真实值
.setPredictionCol("prediction")//模型预测的值
.setMetricName("accuracy")//正确率
val testDF = nbModel.transform(userProfile)
testDF.show()
val accuracy = classificationEvaluator.evaluate(testDF)
logger.info("评估结果 正确率 accuracy==============" + accuracy)
ModelUtils.updateModel2mysql(model.getName, accuracy)
val list = new util.ArrayList[Double]()
list.add(accuracy)
return list
}
}