Mixed-attribute clustering: clustering car models by their attributes (Java implementation)

程序员文章站 2022-07-14 19:22:30

I: Introduction

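Clustering is an important and frequently used method in data mining. The idea is to partition a data set into subsets, each called a cluster, so that objects in different clusters differ substantially while objects within the same cluster are similar. Common approaches include partition-based methods (k-means, k-prototype), hierarchical methods (BIRCH), density-based methods (DBSCAN), and grid-based methods (CLIQUE). Clustering has a wide range of applications, for example segmenting bank customers so that a targeted marketing strategy can be developed for each group. This post covers the k-prototype algorithm, one of the partition-based methods.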

II: How the algorithm works

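The k-prototype algorithm combines k-means and k-modes, so it can cluster data objects that have both numeric and nominal (categorical) attributes. For numeric attributes, similarity between objects is measured with the Euclidean distance; for nominal attributes, with the Hamming distance (0 if the two values match, 1 otherwise). Cluster centers are computed as follows: for a numeric attribute, take the mean of that attribute over the objects in the cluster; for a nominal attribute, take the value that occurs most often among the objects in the cluster.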

The goal of the algorithm is to minimize the objective function through repeated iteration.
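The formulas in the original post were images and did not survive. As a hedged reconstruction that matches the code in section IV (the notation is an assumption, not necessarily the author's): let $N$ be the set of numeric attributes, $C$ the set of nominal attributes, $w_j$ the weight of attribute $j$ (the annotations below use 1 for numeric and 0.01 for nominal attributes), $q_l$ the center of cluster $l$, and $S_l$ the set of objects assigned to it. Note that the code sums squared differences for the numeric part, without taking a square root.

```latex
d(x_i, q_l) = \sum_{j \in N} w_j \,(x_{ij} - q_{lj})^2 \;+\; \sum_{j \in C} w_j \,\delta(x_{ij}, q_{lj}),
\qquad
\delta(a, b) =
\begin{cases}
0, & a = b \\
1, & a \neq b
\end{cases}

E = \sum_{l=1}^{k} \sum_{x_i \in S_l} d(x_i, q_l)
```

$E$ is the objective function that the iteration tries to minimize.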

III: Algorithm steps

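  1. Input k
  2. Initialize k cluster centers
  3. Assign each data object to a cluster
  4. Recompute each cluster center
  5. Repeat steps 3-4 until the objective function no longer changes.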
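The steps above can be sketched on toy data as follows. This is a minimal illustration, not the code from this post: the class `KPrototypesLoopSketch`, the `Point` type with one numeric and one nominal attribute, and the `gamma` weight on nominal mismatches are all assumptions made for the example.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustrative sketch of the k-prototype loop; all names are hypothetical.
final class KPrototypesLoopSketch {

    static final class Point {
        final double num;   // one numeric attribute
        final String cat;   // one nominal attribute
        Point(double num, String cat) { this.num = num; this.cat = cat; }
    }

    // Mixed dissimilarity: squared difference plus gamma for a nominal mismatch.
    static double distance(Point a, Point b, double gamma) {
        double diff = a.num - b.num;
        return diff * diff + (a.cat.equals(b.cat) ? 0.0 : gamma);
    }

    // Runs steps 2-5 and returns the cluster index assigned to each point.
    static int[] cluster(List<Point> points, List<Point> initialCenters,
                         double gamma, int maxIter) {
        List<Point> centers = new ArrayList<>(initialCenters);   // step 2
        int[] assignment = new int[points.size()];
        double previousCost = Double.POSITIVE_INFINITY;
        for (int iter = 0; iter < maxIter; iter++) {
            // Step 3: assign every point to its nearest center.
            double cost = 0.0;
            for (int i = 0; i < points.size(); i++) {
                double best = Double.POSITIVE_INFINITY;
                for (int c = 0; c < centers.size(); c++) {
                    double d = distance(points.get(i), centers.get(c), gamma);
                    if (d < best) { best = d; assignment[i] = c; }
                }
                cost += best;
            }
            // Step 5: stop once the objective function no longer changes.
            if (cost == previousCost) break;
            previousCost = cost;
            // Step 4: mean for the numeric attribute, mode for the nominal one.
            for (int c = 0; c < centers.size(); c++) {
                double sum = 0.0;
                int size = 0;
                Map<String, Integer> counts = new HashMap<>();
                for (int i = 0; i < points.size(); i++) {
                    if (assignment[i] != c) continue;
                    sum += points.get(i).num;
                    size++;
                    counts.merge(points.get(i).cat, 1, Integer::sum);
                }
                if (size == 0) continue;   // keep the old center for an empty cluster
                String mode = null;
                int bestCount = 0;
                for (Map.Entry<String, Integer> e : counts.entrySet()) {
                    if (e.getValue() > bestCount) { bestCount = e.getValue(); mode = e.getKey(); }
                }
                centers.set(c, new Point(sum / size, mode));
            }
        }
        return assignment;
    }
}
```

The `maxIter` bound is a safety net: the loop normally ends earlier, as soon as two consecutive passes give the same objective value.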

IV: Implementation (in Java)

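1. Defining the data object. This post uses car clustering as the example. Java annotations declare each attribute's kind (nominal attributes have flag="c", numeric attributes have flag="n") and its weight. A data object is then wrapped to form a cluster center.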

@SQLTable(tablename = "car_infor_copy")
public class Car implements Cloneable {
    @Column(column = "id")
    private int id;
    @Column(column = "name")
    private String name;
    @Column(column = "price",flag="n",weight = 1)
    private double price;
    @Column(column = "level",flag ="c",weight = 0.01)
    private String level;
    @Column(column = "type")
    private String type;
    @Column(column = "ensureance")
    private String ensureance;
    @Column(column = "oilAss",flag="n",weight = 1)
    private double oilAss;    // fuel consumption
    @Column(column = "length")
    private int length;
    @Column(column = "width")
    private int width;
    @Column(column = "height")
    private int height;
    @Column(column = "wheelDis",flag = "n",weight = 1)
    private double wheelDis; // wheelbase
    @Column(column = "weight",flag="n",weight = 1)
    private double weight;
    @Column(column = "seatNum",flag = "c",weight = 0.01)
    private int seatNum;
    @Column(column = "paskageSpace")
    private int paskageSpace;
    @Column(column = "oilSpace")
    private int oilSpace;
    @Column(column = "elimission",flag="n",weight = 1)
    private double elimission;  // engine displacement
    @Column(column = "kw",flag = "n",weight = 1)
    private double kw; // power
    @Column(column = "Ps")
    private int Ps; // horsepower
    @Column(column = "n",flag = "n",weight = 1)
    private double n; // torque
    @Column(column = "airType",flag = "c",weight = 0.01)
    private String airType; // air intake type
    @Column(column = "cilinderNum",flag="n",weight = 1)
    private double cilinderNum; // number of cylinders
    @Column(column = "oilType")
    private String oilType; // fuel grade
    @Column(column = "boxType",flag = "c",weight = 0.01)
    private String boxType; // transmission type
    @Column(column = "boxNum",flag = "c",weight = 0.01)
    private int boxNum;
    @Column(column = "driveType",flag = "c",weight = 0.01)
    private String driveType;
    @Column(column = "firstType")
    private String firstType;
    @Column(column = "lastType",flag = "c",weight = 0.01)
    private String lastType;
    @Column(column = "carStructure",flag = "c",weight = 0.01)
    private String carStructure;
    @Column(column = "ABS")
    private int ABS;
    @Column(column = "EBD")
    private int EBD;
    @Column(column = "BA")
    private int BA;
    @Column(column = "ARS")
    private int ARS;
    @Column(column = "ESP")
    private int ESP;
    @Column(column = "airBagBum",flag ="n",weight = 1)
    private double airBagBum;
    @Column(column = "presureChecker")
    private int presureChecker; // tire pressure monitoring
    @Column(column = "fixSpeed")
    private int fixSpeed;
    @Column(column = "fixRoad")
    private int fixRoad;
    @Column(column = "changeAid")
    private int changeAid; // lane change assist
    @Column(column = "activeBreak")
    private int activeBreak;
    @Column(column = "tired")
    private int tired;
    @Column(column = "autoPark")
    private int autoPark;
    @Column(column = "autoClose")
    private int autoClose;
    @Column(column = "upAid")
    private int upAid;
    @Column(column = "downAid")
    private int downAid;
    @Column(column = "nightSys")
    private int nightSys;
    @Column(column = "changeableTurn")
    private int changeableTurn; // variable steering
    @Column(column = "frontRader")
    private int frontRader;
    @Column(column = "backRader")
    private int backRader;
    @Column(column = "backVodeo")
    private int backVodeo;
    @Column(column = "selectMode")
    private int selectMode;
    @Column(column = "frontLight",flag = "c",weight = 0.01)
    private String frontLight;
    @Column(column = "autoLight")
    private int autoLight;
    @Column(column = "upWindow",flag = "c",weight = 0.01)
    private String upWindow;
    @Column(column = "ePackge")
    private int ePackge;
    @Column(column = "multiTurner",flag = "c",weight = 0.01)
    private int multiTurner; // multifunction steering wheel
    @Column(column = "tunerShift")
    private int tunerShift;
    @Column(column = "airConditioner",flag = "c",weight = 0.01)
    private String airConditioner;
    @Column(column = "backAirConditioner",flag = "c",weight = 0.01)
    private String backAirConditioner;
    @Column(column = "activeNoise")
    private int activeNoise;
    @Column(column = "seatMeri",flag = "c",weight = 0.01)
    private String seatMeri;  // seat material
    @Column(column = "scree",flag = "c",weight = 0.01)
    private String scree;
    @Column(column = "HUD")
    private int HUD;
    @Column(column = "GPS",flag ="c",weight = 0.01)
    private int GPS;
    @Column(column = "carPaly")
    private int carPaly;
    @Column(column = "voiceCotrol",flag = "c",weight = 0.01)
    private int voiceCotrol;
    @Column(column = "sysave",flag = "n",weight=1)
    private double sysave;
    @Column(column = "udAid",flag="n",weight = 1)
    private double udAid;
    @Column(column = "Raders",flag = "n",weight = 1)
    private double Raders;
    @Column(column = "distance")
    private double distance;
    @Column(column = "centerId")
    private int centerId;

    // getters and setters omitted
}
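The post does not show how the `@Column` annotation is declared. A minimal sketch of what it might look like is given below; the defaults for `flag` and `weight`, and the `ColumnDemo` and `MiniCar` helper classes, are assumptions made for illustration. The one hard requirement is runtime retention, because the clustering code reads the annotation through reflection.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;

// Hypothetical reconstruction: the real declaration is not shown in the post.
@Retention(RetentionPolicy.RUNTIME)   // must be visible to reflection at run time
@Target(ElementType.FIELD)
@interface Column {
    String column();                  // database column name
    String flag() default "";         // "n" = numeric, "c" = nominal, "" = not used for clustering
    double weight() default 0.0;      // weight of the attribute in the distance
}

final class ColumnDemo {

    // Cut-down stand-in for the Car class above.
    static final class MiniCar {
        @Column(column = "price", flag = "n", weight = 1)
        private double price;
        @Column(column = "level", flag = "c", weight = 0.01)
        private String level;
        @Column(column = "name")
        private String name;
    }

    // Counts the fields of a class whose @Column annotation carries the given flag.
    static int countFlag(Class<?> type, String flag) {
        int n = 0;
        for (Field f : type.getDeclaredFields()) {
            Column c = f.getAnnotation(Column.class);
            if (c != null && flag.equals(c.flag())) {
                n++;
            }
        }
        return n;
    }
}
```

With defaults like these, fields annotated only with `column` (such as `name`) take part in persistence but are skipped by the distance computation.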

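1. Loading the data. My data is read from a database and stored in a map. The Center class that wraps a cluster center follows the load method.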

/**
     * Load the data
     */
    public void load(){
        cars=new HashMap<Integer,Car>();
        centers=new HashMap<Integer, Center>();
        try {
            List<Car> list = carInforDao.findAll();
            for(Car car:list){
                cars.put(car.getId(),car);
            }
        }catch (MsgException e){
            System.out.println("Failed to load the data");
            e.printStackTrace();
        }

    }

public class Center {

    private int id;

    // Prototype (center) of the cluster
    private Car center;

    // IDs of the cars currently assigned to this cluster
    private Set<Integer> carId=new HashSet<>();

    // getters and setters omitted
}

2. Initializing the cluster centers. Here k data objects are drawn at random from the data to serve as the cluster centers.

/**
     * Initialize the cluster centers
     * (needs java.util.ArrayList, java.util.Collections and java.util.List)
     */
    public void initCenters(int k){
        if(k<=0||k>cars.size()){
            throw new IllegalArgumentException("k must be between 1 and the number of cars");
        }
        // Fixed seed so that runs are reproducible
        Random random=new Random(47);
        int count=1;
        centers=new HashMap<>();
        // Shuffle the real IDs and take the first k, rather than guessing IDs in a fixed range
        List<Integer> carIds=new ArrayList<>(cars.keySet());
        Collections.shuffle(carIds,random);
        for(Integer id:carIds.subList(0,k)){
            Car car=cars.get(id);
            Center center=new Center();
            center.setId(count);
            center.setCenter(copy(car));
            centers.put(count++,center);
        }
    }

    public Car copy(Car from){
        Car car=new Car();
        car.setId(from.getId());
        car.setName(from.getName());
        car.setPrice(from.getPrice());
        car.setLevel(from.getLevel());
        car.setType(from.getType());
        car.setEnsureance(from.getEnsureance());
        car.setOilAss(from.getOilAss());
        car.setLength(from.getLength());
        car.setWeight(from.getWeight());
        car.setHeight(from.getHeight());
        car.setWheelDis(from.getWheelDis());
        car.setWidth(from.getWidth());
        car.setSeatNum(from.getSeatNum());
        car.setPaskageSpace(from.getPaskageSpace());
        car.setOilSpace(from.getOilSpace());
        car.setElimission(from.getElimission());
        car.setKw(from.getKw());
        car.setPs(from.getPs());
        car.setN(from.getN());
        car.setAirType(from.getAirType());
        car.setCilinderNum(from.getCilinderNum());
        car.setOilType(from.getOilType());
        car.setBoxNum(from.getBoxNum());
        car.setBoxType(from.getBoxType());

        car.setDriveType(from.getDriveType());
        car.setFirstType(from.getFirstType());
        car.setLastType(from.getLastType());
        car.setCarStructure(from.getCarStructure());
        car.setABS(from.getABS());
        car.setEBD(from.getEBD());
        car.setBA(from.getBA());
        car.setARS(from.getARS());
        car.setESP(from.getESP());
        car.setAirBagBum(from.getAirBagBum());
        car.setPresureChecker(from.getPresureChecker());
        car.setFixSpeed(from.getFixSpeed());
        car.setFixRoad(from.getFixRoad());
        car.setChangeAid(from.getChangeAid());
        car.setActiveBreak(from.getActiveBreak());
        car.setTired(from.getTired());
        car.setAutoPark(from.getAutoPark());
        car.setAutoClose(from.getAutoClose());
        car.setUpAid(from.getUpAid());
        car.setDownAid(from.getDownAid());
        car.setNightSys(from.getNightSys());
        car.setChangeableTurn(from.getChangeableTurn());
        car.setFrontRader(from.getFrontRader());
        car.setBackRader(from.getBackRader());
        car.setSelectMode(from.getSelectMode());
        car.setBackVodeo(from.getBackVodeo());
        car.setFrontLight(from.getFrontLight());
        car.setAutoLight(from.getAutoLight());
        car.setUpWindow(from.getUpWindow());
        car.setEPackge(from.getEPackge());
        car.setMultiTurner(from.getMultiTurner());
        car.setTunerShift(from.getTunerShift());
        car.setAirConditioner(from.getAirConditioner());
        car.setBackAirConditioner(from.getBackAirConditioner());
        car.setActiveNoise(from.getActiveNoise());
        car.setSeatMeri(from.getSeatMeri());
        car.setScree(from.getScree());
        car.setHUD(from.getHUD());
        car.setGPS(from.getGPS());
        car.setRaders(from.getRaders());
        car.setCarPaly(from.getCarPaly());
        car.setVoiceCotrol(from.getVoiceCotrol());
        car.setSysave(from.getSysave());
        car.setUdAid(from.getUdAid());
        car.setDistance(from.getDistance());
        car.setCenterId(from.getCenterId());
        return car;

    }
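A hand-written copy method with one setter call per attribute is easy to get out of sync with the class. One alternative is a reflection-based shallow copy, sketched below. This is not the approach used in this post; `ReflectiveCopySketch` and `Sample` are hypothetical names, and the sketch assumes the class has an accessible no-argument constructor and no final fields.

```java
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;

// Illustrative alternative to a field-by-field copy method; names are hypothetical.
final class ReflectiveCopySketch {

    // Stand-in for Car: any class with a no-argument constructor works.
    static final class Sample {
        int id;
        double price;
        String level;
    }

    // Creates a new instance of `type` and copies every non-static declared field into it.
    static <T> T shallowCopy(T from, Class<T> type) {
        try {
            T to = type.getDeclaredConstructor().newInstance();
            for (Field f : type.getDeclaredFields()) {
                if (Modifier.isStatic(f.getModifiers())) {
                    continue;   // static fields belong to the class, not the instance
                }
                f.setAccessible(true);
                f.set(to, f.get(from));
            }
            return to;
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Could not copy " + type.getName(), e);
        }
    }
}
```

The copy is shallow: reference-typed fields point to the same objects as the original, which is harmless for immutable values such as `String`.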

3. Assigning the data objects. Each object is assigned to the cluster whose center is closest to it.

 /**
     * Assign each object to its nearest cluster
     */
    public void distribution(){
        for(Integer i:cars.keySet()){
            Car car=cars.get(i);
            double minDis=Double.MAX_VALUE;
            int centerId=1;
            for(Integer j:centers.keySet()){
                double tmp=calculateDistance(car,centers.get(j).getCenter());
                if(tmp<minDis){
                    minDis=tmp;
                    centerId=j;
                }
            }
            // Remember the chosen cluster and the distance to its center
            car.setCenterId(centerId);
            car.setDistance(minDis);
            Set<Integer> centerSet=centers.get(centerId).getCarId();
            centerSet.add(i);

        }
    }

4. Computing the distance

/**
     * Compute the distance between a data object and a cluster center
     * @param car       the data object
     * @param centerCar the cluster center
     * @return the weighted sum of the per-attribute distances
     */
    public double calculateDistance(Car car,Car centerCar){
        Class<?> entityClass=Car.class;
        Field [] fields=entityClass.getDeclaredFields();
        double distance=0;
        for(Field field:fields){
            double dis=0;
            Column column=field.getAnnotation(Column.class);
            if(column!=null){
                String flag=column.flag();
                double weight=column.weight();
                if("n".equals(flag)){
                   dis=calculateNumericAttr(entityClass,field,car,centerCar)*weight;
                }else if("c".equals(flag)){
                    dis=calculateCatelogeAttr(entityClass,field,car,centerCar)*weight;

                }
            }
            distance+=dis;
        }
        return distance;
    }

    /**
     * Distance between a data object and the center on one numeric attribute
     * @param entityClass class that declares the getter
     * @param field       the numeric attribute
     * @param car         the data object
     * @param center      the cluster center
     * @return the squared difference of the two values
     */

    public double calculateNumericAttr(Class<?> entityClass,Field field,Car car,Car center){
        String fieldName=field.getName();
        String getMedthod="get"+fieldName.substring(0,1).toUpperCase()+fieldName.substring(1);
        Method method=null;
        double distance=0;
        try {
            method = entityClass.getMethod(getMedthod);
        }catch (NoSuchMethodException e){
            System.out.println("Getter not found: "+getMedthod);
            e.printStackTrace();

        }
        if(method!=null){
            double carValue=0;
            double centerValue=0;
            try {
                 carValue = (double)method.invoke(car);
                 centerValue=(double)method.invoke(center);
            }catch (IllegalAccessException e){
                System.out.println("Error while reading the attribute (1)");
                e.printStackTrace();
            }catch (Exception e){
                System.out.println("Error while reading the attribute (2)");
                e.printStackTrace();
            }
            // Squared difference: this attribute's term of the squared Euclidean distance
            distance=(carValue-centerValue)*(carValue-centerValue);
        }
        return distance;
    }

    /**
     * Distance between a data object and the center on one nominal attribute
     * @param entityClass class that declares the getter
     * @param field       the nominal attribute
     * @param car         the data object
     * @param center      the cluster center
     * @return 0 if the two values match, 1 otherwise
     */
    public double calculateCatelogeAttr(Class<?> entityClass,Field field,Car car,Car center){
        String fieldName=field.getName();
        String getMedthod="get"+fieldName.substring(0,1).toUpperCase()+fieldName.substring(1);
        Method method=null;
        double distance=0;
        try {
            method = entityClass.getMethod(getMedthod);
        }catch (NoSuchMethodException e){
            System.out.println("Getter not found: "+getMedthod);
            e.printStackTrace();

        }
        if(method!=null){
             Class<?> returnType=method.getReturnType();
             String returnTypeName=returnType.getSimpleName();
             if("int".equals(returnTypeName)){
                 int carValue=0;
                 int centerValue=0;
                 try{
                     carValue=(int)method.invoke(car);
                     centerValue=(int)method.invoke(center);
                 }catch (Exception e) {
                     System.out.println("Error while reading a nominal attribute of type int");
                     e.printStackTrace();
                 }
                 if(carValue==centerValue){
                     distance=0;
                 }else {
                     distance=1;
                 }
             }else if("String".equals(returnTypeName)){
                 String carValue="";
                 String centerValue="";
                 try{
                     // A missing value is treated as the placeholder "-"
                     carValue=(String)method.invoke(car);
                     if(carValue==null){
                         carValue="-";
                     }
                     centerValue=(String)method.invoke(center);
                     if(centerValue==null){
                         centerValue="-";
                     }
                 }catch(Exception e){
                     System.out.println("Error while reading a nominal attribute of type String");
                     e.printStackTrace();
                 }
                 if(carValue.equals(centerValue)){
                     distance=0;
                 }else {
                     distance=1;
                 }
             }
        }
        return distance;

    }
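To make the effect of the weights concrete, here is a small worked example without reflection. The class `MixedDistanceExample` and the sample values are hypothetical; the weights 1 and 0.01 are the ones used in the annotations of the Car class.

```java
// Illustrative worked example of the weighted mixed distance; names are hypothetical.
final class MixedDistanceExample {

    // Weighted squared differences for numeric attributes plus weighted mismatches for nominal ones.
    static double distance(double[] numA, double[] numB, String[] catA, String[] catB,
                           double numWeight, double catWeight) {
        double sum = 0.0;
        for (int j = 0; j < numA.length; j++) {
            double diff = numA[j] - numB[j];
            sum += numWeight * diff * diff;
        }
        for (int j = 0; j < catA.length; j++) {
            if (!catA[j].equals(catB[j])) {
                sum += catWeight;
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] carNum = {0.2, 0.5};             // normalized numeric attributes
        double[] centerNum = {0.4, 0.1};
        String[] carCat = {"SUV", "AT", "FWD"};   // nominal attributes
        String[] centerCat = {"SUV", "MT", "AWD"};
        // Numeric part: 0.04 + 0.16; nominal part: two mismatches at 0.01 each.
        System.out.println(distance(carNum, centerNum, carCat, centerCat, 1.0, 0.01));
    }
}
```

With these weights the numeric part contributes about 0.20 and the nominal part only 0.02, so on normalized data the numeric attributes dominate the assignment.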

5. Recomputing the cluster centers

/**
     * Recompute the cluster centers
     */
    public void reSetCenter(){
        for(Integer i:centers.keySet()){
            Center center=centers.get(i);
            Set<Integer> carsId=center.getCarId();
            reSetCenterAttr(center,carsId);
            // Empty the cluster's member set, ready for the next assignment pass
            carsId.clear();

        }
    }

    public void reSetCenterAttr(Center center,Set<Integer> carsId){
        Class entityClass=Car.class;
        Field [] fields=entityClass.getDeclaredFields();
        for(Field field:fields){
            Column column=field.getAnnotation(Column.class);
            if(column!=null){
                String flag=column.flag();
                if("c".equals(flag)){
                    reSetCenterAttrOfCate(field,center.getCenter(),carsId);
                }else if("n".equals(flag)){
                    reSetCenterAttrOfNum(field,center.getCenter(),carsId);

                }
            }
        }

    }

    public void reSetCenterAttrOfCate(Field field,Car center,Set<Integer> carsId){
        // Keep the previous value if no object was assigned to this cluster
        if(carsId.isEmpty()){
            return;
        }
        Class<?> entityClass=Car.class;
        String fieldName=field.getName();
        String getMedthodName="get"+fieldName.substring(0,1).toUpperCase()+fieldName.substring(1);
        String setMedthodName="set"+fieldName.substring(0,1).toUpperCase()+fieldName.substring(1);
        Method method;
        Method setMethod;
        try{
            method=entityClass.getMethod(getMedthodName);
            setMethod=entityClass.getMethod(setMedthodName,field.getType());
        }catch (NoSuchMethodException e){
            e.printStackTrace();
            return;
        }
        // Only int and String attributes are handled
        String returnTypeName=method.getReturnType().getSimpleName();
        if(!"int".equals(returnTypeName)&&!"String".equals(returnTypeName)){
            return;
        }
        // Count how often each value occurs in the cluster
        Map<Object,Integer> recorder=new HashMap<>();
        for(Integer i:carsId){
            try {
                Object key=method.invoke(cars.get(i));
                if(key==null){
                    key="-";
                }
                recorder.merge(key,1,Integer::sum);
            }catch (Exception e){
                System.out.println("Error reading a data object while updating the cluster center");
            }
        }
        // Find the most frequent value (the mode)
        Object maxKey=null;
        int maxCount=0;
        for(Map.Entry<Object,Integer> entry:recorder.entrySet()){
            if(maxCount<entry.getValue()){
                maxCount=entry.getValue();
                maxKey=entry.getKey();
            }
        }
        if(maxKey==null){
            return;
        }
        // Set the attribute on the center
        try{
            setMethod.invoke(center,maxKey);
        }catch (Exception e){
            System.out.println("Error setting the cluster-center attribute: "+fieldName);
            e.printStackTrace();
        }
    }
    public void reSetCenterAttrOfNum(Field field,Car center,Set<Integer> carsId){
        // An empty cluster has no mean; keep the previous value instead of writing NaN
        if(carsId.isEmpty()){
            return;
        }
        Class<?> entityClass=Car.class;
        String fieldName=field.getName();
        String getMedthodName="get"+fieldName.substring(0,1).toUpperCase()+fieldName.substring(1);
        String setMedthodName="set"+fieldName.substring(0,1).toUpperCase()+fieldName.substring(1);
        Method method;
        Method setMethod;
        try{
            method=entityClass.getMethod(getMedthodName);
            setMethod=entityClass.getMethod(setMedthodName,field.getType());
        }catch (NoSuchMethodException e){
            e.printStackTrace();
            return;
        }
        // Sum the attribute over all members of the cluster
        double value=0.0;
        for(Integer i:carsId){
            try{
                double tmp=(double)method.invoke(cars.get(i));
                value+=tmp;
            }catch (Exception e){
                System.out.println("Error reading a data object's value for numeric attribute "+fieldName);
                e.printStackTrace();
            }
        }
        // The new center value is the mean over the cluster members
        try{
            setMethod.invoke(center,Double.valueOf(value/carsId.size()));
        }catch (Exception e){
            System.out.println("Error setting the numeric cluster-center attribute "+fieldName);
            e.printStackTrace();
        }
    }

V: Experimental results

1. I clustered 634 car models; during preprocessing the numeric attributes were normalized. The visualization shows 20 of the resulting clusters.


[Figure: visualization of 20 of the resulting clusters; image not available]
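The post does not say which normalization was applied to the numeric attributes. A common choice is min-max scaling to the range [0, 1], sketched below as an assumption; `MinMaxSketch` is a hypothetical name. Scaling matters here because attributes such as price and cylinder count live on very different scales, and the squared differences would otherwise be dominated by the largest one.

```java
// Illustrative min-max normalization of one numeric column; an assumed preprocessing step.
final class MinMaxSketch {

    // Rescales the values to [0, 1]; a constant column maps to all zeros.
    static double[] normalize(double[] values) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        double range = max - min;
        double[] out = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = range == 0.0 ? 0.0 : (values[i] - min) / range;
        }
        return out;
    }
}
```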

2. The objective function. The number of iterations here is 10 (I think this could be improved).


[Figure: objective function plot; image not available]
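One possible improvement over a fixed iteration count is to stop when the objective function changes by less than a small tolerance. The sketch below assumes the objective is the sum of each object's distance to its assigned center (the value stored by `car.setDistance` during assignment); `ObjectiveSketch` and the tolerance parameter are hypothetical.

```java
// Illustrative convergence check based on the objective function; names are hypothetical.
final class ObjectiveSketch {

    // Objective: the sum of every object's distance to its assigned cluster center.
    static double objective(double[] distancesToAssignedCenter) {
        double sum = 0.0;
        for (double d : distancesToAssignedCenter) {
            sum += d;
        }
        return sum;
    }

    // True when two consecutive objective values differ by at most the tolerance.
    static boolean converged(double previous, double current, double tolerance) {
        return Math.abs(previous - current) <= tolerance;
    }
}
```

In the main loop this would be evaluated after each assignment pass, with a maximum iteration count kept as a safety bound.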


VI: Summary

1. The k-prototype algorithm can handle clustering problems on mixed data types.

2. The k-prototype algorithm is sensitive to the choice of initial cluster centers.
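A common way to reduce this sensitivity is to run the clustering several times with different random seeds and keep the run with the lowest objective value. The sketch below is not part of the implementation in this post: `RestartSketch` is a hypothetical helper, and the `costForSeed` callback stands for "initialize the centers with this seed, run the loop, return the final objective".

```java
import java.util.function.LongToDoubleFunction;

// Illustrative multi-restart helper; names are hypothetical.
final class RestartSketch {

    // Runs the clustering once per seed and returns the seed that gave the lowest objective.
    static long bestSeed(long[] seeds, LongToDoubleFunction costForSeed) {
        if (seeds.length == 0) {
            throw new IllegalArgumentException("need at least one seed");
        }
        long best = seeds[0];
        double bestCost = Double.POSITIVE_INFINITY;
        for (long seed : seeds) {
            double cost = costForSeed.applyAsDouble(seed);
            if (cost < bestCost) {
                bestCost = cost;
                best = seed;
            }
        }
        return best;
    }
}
```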