混合属性聚类算法——以汽车属性为维度对汽车车型进行聚类(Java实现)
一:前言
聚类算法是数据挖掘中重要且常用的方法。聚类算法的思想是将一个数据集划分成不同的子集,每一个子集称为一个簇,使得簇之间的数据对象差异性比较大,簇内数据对象之间的相似度比较大。常见的算法包括基于划分的方法(k-means,k-prototype)、基于层次的方法(BIRCH)、基于密度的方法(DBSCAN)、基于网格的方法(CLIQUE)等。聚类算法的应用领域非常广泛,比如对银行客户进行聚类,针对不同客户群体开发有针对性的营销策略等。这篇博客主要介绍基于划分方法中的k-prototype算法。
二:算法原理
k-prototype算法结合k-means和k-modes算法而来,可以对具有数值型属性和标称型属性的数据对象进行聚类。对于数值型属性,使用欧氏距离来衡量对象之间的相似性;对于标称型属性,使用汉明距离来衡量对象的相似性。簇中心的计算:对于数值型属性,使用每个簇中数据对象的相应属性的平均值;对于标称型属性,使用每个簇中数据对象相应属性出现次数最多的值。
算法的目的是通过不断地迭代来最小化目标函数。
三:算法流程
- 输入K
- 初始化k个簇中心
- 将数据对象划分至每个簇中
- 重新计算每个簇中心
- 重复第3、4步(划分数据对象、重新计算簇中心),直至目标函数不再发生变化。
四:代码实现(java实现)
1,数据对象的定义。这篇博客以汽车聚类为例,使用了java的注解,将属性的分类(标称型属性flag="c",数值型属性flag="n")和属性的权重定义好,再将数据对象包装成一个簇中心。
/**
 * Data object describing one car model; annotated with the table name "car_infor_copy".
 *
 * Every field carries a {@code @Column} annotation naming its column. For the clustering
 * code in this file, flag = "n" marks a numeric attribute and flag = "c" marks a
 * categorical attribute, and the annotation's weight scales that attribute's contribution
 * to the distance. Fields whose annotation sets no flag are presumably ignored by the
 * distance computation (this depends on the default of Column.flag(), which is not
 * visible here - TODO confirm).
 */
@SQLTable(tablename = "car_infor_copy")
public class Car implements Cloneable {
// Also used as the key of the car map when data is loaded.
@Column(column = "id")
private int id;
@Column(column = "name")
private String name;
@Column(column = "price",flag="n",weight = 1)
private double price;
@Column(column = "level",flag ="c",weight = 0.01)
private String level;
@Column(column = "type")
private String type;
@Column(column = "ensureance")
private String ensureance;
@Column(column = "oilAss",flag="n",weight = 1)
private double oilAss; // fuel consumption
@Column(column = "length")
private int length;
@Column(column = "width")
private int width;
@Column(column = "height")
private int height;
@Column(column = "wheelDis",flag = "n",weight = 1)
private double wheelDis; // wheelbase
@Column(column = "weight",flag="n",weight = 1)
private double weight;
@Column(column = "seatNum",flag = "c",weight = 0.01)
private int seatNum;
@Column(column = "paskageSpace")
private int paskageSpace;
@Column(column = "oilSpace")
private int oilSpace;
@Column(column = "elimission",flag="n",weight = 1)
private double elimission; // engine displacement
@Column(column = "kw",flag = "n",weight = 1)
private double kw;// power
@Column(column = "Ps")
private int Ps;// horsepower
@Column(column = "n",flag = "n",weight = 1)
private double n; // torque
@Column(column = "airType",flag = "c",weight = 0.01)
private String airType;// air intake type
@Column(column = "cilinderNum",flag="n",weight = 1)
private double cilinderNum;// number of cylinders
@Column(column = "oilType")
private String oilType; // fuel grade
@Column(column = "boxType",flag = "c",weight = 0.01)
private String boxType; // gearbox type
@Column(column = "boxNum",flag = "c",weight = 0.01)
private int boxNum;
@Column(column = "driveType",flag = "c",weight = 0.01)
private String driveType;
@Column(column = "firstType")
private String firstType;
@Column(column = "lastType",flag = "c",weight = 0.01)
private String lastType;
@Column(column = "carStructure",flag = "c",weight = 0.01)
private String carStructure;
@Column(column = "ABS")
private int ABS;
@Column(column = "EBD")
private int EBD;
@Column(column = "BA")
private int BA;
@Column(column = "ARS")
private int ARS;
@Column(column = "ESP")
private int ESP;
@Column(column = "airBagBum",flag ="n",weight = 1)
private double airBagBum;
@Column(column = "presureChecker")
private int presureChecker;// tire pressure monitoring
@Column(column = "fixSpeed")
private int fixSpeed;
@Column(column = "fixRoad")
private int fixRoad;
@Column(column = "changeAid")
private int changeAid; // lane-change assist
@Column(column = "activeBreak")
private int activeBreak;
@Column(column = "tired")
private int tired;
@Column(column = "autoPark")
private int autoPark;
@Column(column = "autoClose")
private int autoClose;
@Column(column = "upAid")
private int upAid;
@Column(column = "downAid")
private int downAid;
@Column(column = "nightSys")
private int nightSys;
@Column(column = "changeableTurn")
private int changeableTurn; // variable steering
@Column(column = "frontRader")
private int frontRader;
@Column(column = "backRader")
private int backRader;
@Column(column = "backVodeo")
private int backVodeo;
@Column(column = "selectMode")
private int selectMode;
@Column(column = "frontLight",flag = "c",weight = 0.01)
private String frontLight;
@Column(column = "autoLight")
private int autoLight;
@Column(column = "upWindow",flag = "c",weight = 0.01)
private String upWindow;
@Column(column = "ePackge")
private int ePackge;
@Column(column = "multiTurner",flag = "c",weight = 0.01)
private int multiTurner; // multifunction steering wheel
@Column(column = "tunerShift")
private int tunerShift;
@Column(column = "airConditioner",flag = "c",weight = 0.01)
private String airConditioner;
@Column(column = "backAirConditioner",flag = "c",weight = 0.01)
private String backAirConditioner;
@Column(column = "activeNoise")
private int activeNoise;
@Column(column = "seatMeri",flag = "c",weight = 0.01)
private String seatMeri; // seat material
@Column(column = "scree",flag = "c",weight = 0.01)
private String scree;
@Column(column = "HUD")
private int HUD;
@Column(column = "GPS",flag ="c",weight = 0.01)
private int GPS;
@Column(column = "carPaly")
private int carPaly;
@Column(column = "voiceCotrol",flag = "c",weight = 0.01)
private int voiceCotrol;
@Column(column = "sysave",flag = "n",weight=1)
private double sysave;
@Column(column = "udAid",flag="n",weight = 1)
private double udAid;
@Column(column = "Raders",flag = "n",weight = 1)
private double Raders;
// Distance from this car to the center it is currently assigned to (written by distribution()).
@Column(column = "distance")
private double distance;
// Id of the cluster this car is currently assigned to (written by distribution()).
@Column(column = "centerId")
private int centerId;
1,加载数据:数据从数据库中读取出来,放在map集合中。
/**
 * Loads all cars through the DAO and indexes them by car id.
 *
 * Both the car map and the center map are replaced with fresh, empty maps first.
 * If the DAO throws a MsgException, the failure is reported on stdout and the car
 * map keeps whatever had been loaded up to that point.
 */
public void load(){
    centers=new HashMap<Integer, Center>();
    cars=new HashMap<Integer,Car>();
    try {
        for(Car loaded:carInforDao.findAll()){
            cars.put(loaded.getId(),loaded);
        }
    }catch (MsgException failure){
        System.out.println("数据加载失败");
        failure.printStackTrace();
    }
}
// One cluster: an id, the prototype car used as the cluster center, and the member car ids.
public class Center {
// Cluster identifier; also used as the key in the centers map.
private int id;
// Prototype car whose attribute values represent this cluster when distances are computed.
private Car center;
// Ids of the cars currently assigned to this cluster.
private Set<Integer> carId=new HashSet<>();
2,初始化聚类中心,这里是从数据中随机抽取k个数据对象作为簇中心。
/**
 * Initialises the cluster centers by drawing {@code k} distinct cars at random and
 * using a copy of each drawn car as the prototype of a new cluster.
 *
 * Clusters are numbered 1..k. Any previously stored centers are discarded.
 * The draw uses a fixed seed so repeated runs on the same data pick the same centers.
 *
 * @param k number of clusters to create; a value of 0 or less creates no centers
 * @throws IllegalArgumentException if {@code k} is larger than the number of loaded cars
 */
public void initCenters(int k){
    centers=new HashMap<>();
    if(k<=0){
        return;
    }
    // Draw from the ids that actually exist instead of guessing ids in a hard-coded range.
    Integer[] ids=cars.keySet().toArray(new Integer[0]);
    if(k>ids.length){
        throw new IllegalArgumentException(
                "k ("+k+") exceeds the number of loaded cars ("+ids.length+")");
    }
    // Fixed seed: keeps the initial centers reproducible between runs.
    Random random=new Random(47);
    // Partial Fisher-Yates shuffle: after step i, ids[0..i] is a uniform sample without repeats.
    for(int i=0;i<k;i++){
        int j=i+random.nextInt(ids.length-i);
        Integer picked=ids[j];
        ids[j]=ids[i];
        ids[i]=picked;

        Center center=new Center();
        center.setId(i+1);
        // Copy the car so that later center updates never modify the original data object.
        center.setCenter(copy(cars.get(picked)));
        centers.put(i+1,center);
    }
}
/**
 * Creates an independent copy of a car so it can serve as a mutable cluster center
 * without changing the original data object.
 *
 * @param from the car to copy; must not be null
 * @return a new Car carrying the same attribute values as {@code from}
 */
public Car copy(Car from){
    // A fresh instance is required; the setters below would fail on a null reference.
    Car car=new Car();
    car.setId(from.getId());
    car.setName(from.getName());
    car.setPrice(from.getPrice());
    car.setLevel(from.getLevel());
    car.setType(from.getType());
    car.setEnsureance(from.getEnsureance());
    car.setOilAss(from.getOilAss());
    car.setLength(from.getLength());
    car.setWeight(from.getWeight());
    car.setHeight(from.getHeight());
    car.setWheelDis(from.getWheelDis());
    car.setWidth(from.getWidth());
    car.setSeatNum(from.getSeatNum());
    car.setPaskageSpace(from.getPaskageSpace());
    car.setOilSpace(from.getOilSpace());
    car.setElimission(from.getElimission());
    car.setKw(from.getKw());
    car.setPs(from.getPs());
    car.setN(from.getN());
    car.setAirType(from.getAirType());
    car.setCilinderNum(from.getCilinderNum());
    car.setOilType(from.getOilType());
    car.setBoxNum(from.getBoxNum());
    car.setBoxType(from.getBoxType());
    car.setDriveType(from.getDriveType());
    car.setFirstType(from.getFirstType());
    car.setLastType(from.getLastType());
    car.setCarStructure(from.getCarStructure());
    car.setABS(from.getABS());
    car.setEBD(from.getEBD());
    car.setBA(from.getBA());
    car.setARS(from.getARS());
    car.setESP(from.getESP());
    car.setAirBagBum(from.getAirBagBum());
    car.setPresureChecker(from.getPresureChecker());
    car.setFixSpeed(from.getFixSpeed());
    car.setFixRoad(from.getFixRoad());
    car.setChangeAid(from.getChangeAid());
    car.setActiveBreak(from.getActiveBreak());
    car.setTired(from.getTired());
    car.setAutoPark(from.getAutoPark());
    car.setAutoClose(from.getAutoClose());
    car.setUpAid(from.getUpAid());
    car.setDownAid(from.getDownAid());
    car.setNightSys(from.getNightSys());
    car.setChangeableTurn(from.getChangeableTurn());
    car.setFrontRader(from.getFrontRader());
    car.setBackRader(from.getBackRader());
    car.setSelectMode(from.getSelectMode());
    car.setBackVodeo(from.getBackVodeo());
    car.setFrontLight(from.getFrontLight());
    car.setAutoLight(from.getAutoLight());
    car.setUpWindow(from.getUpWindow());
    car.setEPackge(from.getEPackge());
    car.setMultiTurner(from.getMultiTurner());
    car.setTunerShift(from.getTunerShift());
    car.setAirConditioner(from.getAirConditioner());
    car.setBackAirConditioner(from.getBackAirConditioner());
    car.setActiveNoise(from.getActiveNoise());
    car.setSeatMeri(from.getSeatMeri());
    car.setScree(from.getScree());
    car.setHUD(from.getHUD());
    car.setGPS(from.getGPS());
    car.setCarPaly(from.getCarPaly());
    car.setVoiceCotrol(from.getVoiceCotrol());
    car.setSysave(from.getSysave());
    car.setUdAid(from.getUdAid());
    // Raders is a numeric clustering attribute; without this the copied center would start at 0.
    car.setRaders(from.getRaders());
    car.setDistance(from.getDistance());
    car.setCenterId(from.getCenterId());
    return car;
}
3,划分数据对象,将对象划分至与其距离最小的簇中。
/**
 * Assigns every car to the cluster whose center is nearest to it.
 *
 * For each car the distance to every center is computed with calculateDistance; the
 * car's centerId and distance are updated and the car id is added to the chosen
 * cluster's member set. Does nothing when no centers exist.
 */
public void distribution(){
    if(centers==null||centers.isEmpty()){
        // Without centers there is nothing a car could be assigned to.
        return;
    }
    for(Map.Entry<Integer,Car> carEntry:cars.entrySet()){
        Car car=carEntry.getValue();
        // Start from infinity so that any finite distance wins, however large it is.
        double minDis=Double.POSITIVE_INFINITY;
        Integer nearestId=null;
        for(Map.Entry<Integer,Center> centerEntry:centers.entrySet()){
            double tmp=calculateDistance(car,centerEntry.getValue().getCenter());
            if(tmp<minDis){
                minDis=tmp;
                nearestId=centerEntry.getKey();
            }
        }
        if(nearestId==null){
            // No distance was comparable (all NaN or infinite): fall back to an existing center
            // rather than to a hard-coded id that might not be present.
            nearestId=centers.keySet().iterator().next();
        }
        car.setCenterId(nearestId);
        car.setDistance(minDis);
        centers.get(nearestId).getCarId().add(carEntry.getKey());
    }
}
4,计算距离
/**
 * Computes the weighted distance between a car and a cluster center.
 *
 * Walks over every declared field of Car that carries a Column annotation. Fields
 * flagged "n" contribute their numeric distance and fields flagged "c" their
 * categorical distance, each multiplied by the annotation's weight. All other fields
 * contribute nothing.
 *
 * @param car the data object
 * @param centerCar the prototype car of the cluster center
 * @return the sum of all weighted per-attribute distances
 */
public double calculateDistance(Car car,Car centerCar){
    final Class<?> carType=Car.class;
    double total=0;
    for(Field attribute:carType.getDeclaredFields()){
        Column meta=attribute.getAnnotation(Column.class);
        if(meta==null){
            continue;
        }
        String kind=meta.flag();
        double scale=meta.weight();
        if("n".equals(kind)){
            total+=calculateNumericAttr(carType,attribute,car,centerCar)*scale;
        }else if("c".equals(kind)){
            total+=calculateCatelogeAttr(carType,attribute,car,centerCar)*scale;
        }
    }
    return total;
}
/**
 * Computes the distance between a car and a center for one numeric attribute.
 *
 * The attribute is read from both objects through its public getter
 * ("get" + capitalised field name); the result is the squared difference of the two
 * values. If the getter cannot be found the result is 0. If reading a value fails, the
 * error is printed and the value that could not be read counts as 0.
 *
 * @param entityClass class on which the getter is looked up
 * @param field the attribute to compare
 * @param car the data object
 * @param center the prototype car of the cluster center
 * @return the squared difference between the two attribute values
 */
public double calculateNumericAttr(Class entityClass,Field field,Car car,Car center){
    String attrName=field.getName();
    String getterName="get"+attrName.substring(0,1).toUpperCase()+attrName.substring(1);
    Method getter;
    try {
        getter=entityClass.getMethod(getterName);
    }catch (NoSuchMethodException missing){
        System.out.println("没有找到get方法");
        missing.printStackTrace();
        return 0;
    }
    double own=0;
    double prototype=0;
    try {
        own=(double)getter.invoke(car);
        prototype=(double)getter.invoke(center);
    }catch (IllegalAccessException denied){
        System.out.println("获取属性时出错!1");
        denied.printStackTrace();
    }catch (Exception other){
        System.out.println("获取属性时出错!2");
        other.printStackTrace();
    }
    double diff=own-prototype;
    return diff*diff;
}
/**
 * Computes the distance between a car and a center for one categorical attribute.
 *
 * The attribute is read from both objects through its public getter. Getters returning
 * int or String are supported: the distance is 0 when both values are equal and 1
 * otherwise. A null String value is treated as "-". Any other getter return type, or a
 * missing getter, yields 0.
 *
 * @param entityClass class on which the getter is looked up
 * @param field the attribute to compare
 * @param car the data object
 * @param center the prototype car of the cluster center
 * @return 0 for matching values, 1 for differing values
 */
public double calculateCatelogeAttr(Class entityClass,Field field,Car car,Car center){
    String attrName=field.getName();
    String getterName="get"+attrName.substring(0,1).toUpperCase()+attrName.substring(1);
    Method getter;
    try {
        getter=entityClass.getMethod(getterName);
    }catch (NoSuchMethodException missing){
        System.out.println("没有找到get方法");
        missing.printStackTrace();
        return 0;
    }
    String typeName=getter.getReturnType().getSimpleName();
    if("int".equals(typeName)){
        int own=0;
        int prototype=0;
        try{
            own=(int)getter.invoke(car);
            prototype=(int)getter.invoke(center);
        }catch (Exception failure) {
            System.out.println("读取分类型属性时出现错误,整数类型");
            failure.printStackTrace();
        }
        return own==prototype?0:1;
    }
    if("String".equals(typeName)){
        String own="";
        String prototype="";
        try{
            String rawOwn=(String)getter.invoke(car);
            own=rawOwn!=null?rawOwn:"-";
            String rawPrototype=(String)getter.invoke(center);
            prototype=rawPrototype!=null?rawPrototype:"-";
        }catch(Exception failure){
            System.out.println("读取分类型属性时出现错误,字符串类型");
            failure.printStackTrace();
        }
        return own.equals(prototype)?0:1;
    }
    // Unsupported attribute type: contributes no distance.
    return 0;
}
5,重新计算簇中心
/**
 * Recomputes every cluster center from the cars currently assigned to it, then
 * empties each cluster's member set so the next assignment pass starts clean.
 */
public void reSetCenter(){
    for(Center cluster:centers.values()){
        Set<Integer> members=cluster.getCarId();
        reSetCenterAttr(cluster,members);
        // Clear the members of this cluster in preparation for the next assignment pass.
        members.clear();
    }
}
/**
 * Updates every clustering attribute of one center from its member cars.
 *
 * Fields of Car annotated with flag "c" are handed to reSetCenterAttrOfCate and
 * fields with flag "n" to reSetCenterAttrOfNum; all other fields are left untouched.
 *
 * @param center the cluster whose prototype car is updated
 * @param carsId ids of the cars assigned to the cluster
 */
public void reSetCenterAttr(Center center,Set<Integer> carsId){
    for(Field attribute:Car.class.getDeclaredFields()){
        Column meta=attribute.getAnnotation(Column.class);
        if(meta==null){
            continue;
        }
        switch(meta.flag()){
            case "c":
                reSetCenterAttrOfCate(attribute,center.getCenter(),carsId);
                break;
            case "n":
                reSetCenterAttrOfNum(attribute,center.getCenter(),carsId);
                break;
            default:
                break;
        }
    }
}
/**
 * Recomputes one categorical attribute of a cluster center as the most frequent value
 * of that attribute among the cluster's member cars.
 *
 * Only attributes whose getter returns int or String are handled; a null String value
 * is counted as "-". When several values share the highest count, the one met first
 * while iterating the count map wins. If the cluster has no members (or no value could
 * be read), the center keeps its previous value.
 *
 * @param field the attribute to update
 * @param center the prototype car of the cluster center
 * @param carsId ids of the cars assigned to the cluster
 */
public void reSetCenterAttrOfCate(Field field,Car center,Set<Integer> carsId){
    String fieldName=field.getName();
    String suffix=fieldName.substring(0,1).toUpperCase()+fieldName.substring(1);
    Method getter;
    Method setter;
    try{
        getter=Car.class.getMethod("get"+suffix);
        setter=Car.class.getMethod("set"+suffix,field.getType());
    }catch (NoSuchMethodException e){
        // Without both accessors the attribute can neither be read nor written.
        e.printStackTrace();
        return;
    }
    String returnTypeName=getter.getReturnType().getSimpleName();
    if(!"int".equals(returnTypeName)&&!"String".equals(returnTypeName)){
        return;
    }

    // Count how often each value occurs among the member cars.
    Map<Object,Integer> recorder=new HashMap<>();
    for(Integer i:carsId){
        try{
            Object key=getter.invoke(cars.get(i));
            if(key==null){
                // Only possible for String attributes; an int getter never yields null.
                key="-";
            }
            Integer seen=recorder.get(key);
            recorder.put(key,seen==null?1:seen+1);
        }catch (ReflectiveOperationException|RuntimeException e){
            System.out.println("更新簇中心时读取数据对象出错");
            e.printStackTrace();
        }
    }
    if(recorder.isEmpty()){
        // Empty cluster: keep the previous center value instead of overwriting it.
        return;
    }

    // Find the most frequent value once, after all members have been counted.
    Object mode=null;
    int modeCount=0;
    for(Map.Entry<Object,Integer> entry:recorder.entrySet()){
        if(modeCount<entry.getValue()){
            modeCount=entry.getValue();
            mode=entry.getKey();
        }
    }

    try{
        setter.invoke(center,mode);
    }catch (ReflectiveOperationException|RuntimeException e){
        System.out.println("设置簇中心属性时出现错误!:属性为"+fieldName);
        e.printStackTrace();
    }
}
/**
 * Recomputes one numeric attribute of a cluster center as the mean of that attribute
 * over the cluster's member cars.
 *
 * The mean is taken over the values that could actually be read. If the cluster has no
 * members, or no value could be read, the center keeps its previous value; dividing by
 * zero would otherwise store NaN in the center and make every later distance NaN.
 *
 * @param field the attribute to update; its getter must return double
 * @param center the prototype car of the cluster center
 * @param carsId ids of the cars assigned to the cluster
 */
public void reSetCenterAttrOfNum(Field field,Car center,Set<Integer> carsId){
    if(carsId==null||carsId.isEmpty()){
        return;
    }
    String fieldName=field.getName();
    String suffix=fieldName.substring(0,1).toUpperCase()+fieldName.substring(1);
    Method getter;
    Method setter;
    try{
        getter=Car.class.getMethod("get"+suffix);
        setter=Car.class.getMethod("set"+suffix,field.getType());
    }catch (NoSuchMethodException e){
        // Without both accessors the attribute can neither be read nor written.
        e.printStackTrace();
        return;
    }

    double sum=0.0;
    int counted=0;
    for(Integer i:carsId){
        try{
            sum+=(double)getter.invoke(cars.get(i));
            counted++;
        }catch (ReflectiveOperationException|RuntimeException e){
            System.out.println("设置簇中心数值型属性,取出数据对象数据时出错"+fieldName);
            e.printStackTrace();
        }
    }
    if(counted==0){
        return;
    }

    try{
        setter.invoke(center,Double.valueOf(sum/counted));
    }catch (ReflectiveOperationException|RuntimeException e){
        System.out.println("设置簇中心数值型属性时最后一步出错!flag "+fieldName);
        e.printStackTrace();
    }
}
五:实验结果
1,博主对634种车型进行聚类,数据预处理时对数值型属性做了归一化处理。可视化结果中选取了其中20个簇进行展示。
2,目标函数,这里迭代的次数为10(感觉有待改进)
六:总结
1,k-prototype算法可以处理混合数据类型的聚类问题。
2,k-prototype算法容易受到簇中心选取的影响。
下一篇: K-means 算法原理