欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

人工智能 - K近邻分类算法的java实现

程序员文章站 2024-03-15 11:18:35
...

K近邻分类算法的java实现 手动输入K值

算法介绍:

关于算法的介绍以及实验的要求在上一个博客已经介绍,这里不再赘述。链接:
人工智能 - K近邻分类算法的Python实现

代码:

package KnnYin;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.Scanner;
class DisTy implements Comparable<DisTy>{
     Double dis;
     Integer type;

    public DisTy(Double dis, Integer type) {
        this.dis = dis;
        this.type = type;
    }

    @Override
    public int compareTo(DisTy o) {
        return this.dis.compareTo(o.dis);
    }

    @Override
    public String toString() {
        return "DisTy{" +
                "dis=" + dis +
                ", type=" + type +
                '}';
    }
}

public class Knn {
    public static void main(String[] args) {
        File test = new File("D:/iris-data-testing.txt");
        File tran = new File("D:/iris-data-training.txt");
       // int cnt = 0;
        System.out.println("输入k");
        Scanner scs = new Scanner(System.in);
        int ans = scs.nextInt();
        try {
            Scanner sct = new Scanner(test);
            //System.out.println(scn.nextDouble());
            int[] type = new int[31];
            int cnt = 0;
            for (int i = 0; i < 31; i++) {
                Double[] ts = new Double[4];  //花的前四个数据double
                for (int j = 0; j < 4; j++) {
                    ts[j] = sct.nextDouble();
                    //System.out.print(ts[j]+" ");
                }
                type[i] = sct.nextInt();//花的类型
                //System.out.print(type[i]+" ");
                DisTy[] name = new DisTy[120];

                try {
                    Scanner scn = new Scanner(tran);

                    for (int j = 0; j < 120; j++){
                        Double[] tn = new Double[4];
                        for (int k = 0; k < 4; k++) {
                            tn[k] = scn.nextDouble();

                        }
                        double sum = 0;

                        for (int k = 0; k < 4; k++) {
                            sum += Math.pow(tn[k]-ts[k],2);
                        }

                        int h = scn.nextInt();
                        name[j] = new DisTy(Math.sqrt(sum),h);
                    }

                    int count[] = new int[5];
                    for(int j = 1;j < 4;j++) {
                        count[j] = 0;
                    }
                    Arrays.sort(name);
                    for (int j = 0; j < ans; j++) {
                        count[name[j].type]++;
                    }
                    int bestCount = 0,bestType = -1;
                    for(int j = 1;j<4;j++) {
                        if(count[j]>bestCount) {
                            bestCount = count[j];
                            bestType = j;
                        }
                    }

                    if(bestType == type[i]) {
                        cnt = cnt + 1;
                    }

                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
            double r = cnt/31.0;
            System.out.println("正确率为:"+r);

        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
    }
}

结果:

人工智能 - K近邻分类算法的java实现
想求出所有K值结果看上一个博客。

思路(自定义数组的一个属性对数组进行排序):

思路和上一个差不多也是读取每一个要test的数据,去和训练集中的每一个数据求欧拉距离,求出的每一个数据的距离和类型放进一个类里面。然后就是根据自定义数组的一个属性对数组进行排序,在这里记录一下,备忘:
首先对类实现Comparable<~>接口,重写compareTo方法:

class DisTy implements Comparable<DisTy>{
     Double dis;
     Integer type;

    public DisTy(Double dis, Integer type) { //构造器
        this.dis = dis;
        this.type = type;
    }
    @Override
    public int compareTo(DisTy o) {   //   重写comparTo方法
        return this.dis.compareTo(o.dis);
    }
  }
}

然后用对数组操作的方法:

Arrays.sort(name);

注意在给自定义数组赋值时,要在new的时候初始化传参

DisTy[] name = new DisTy[120];
...
name[j] = new DisTy(Math.sqrt(sum),h);

这样就完成了排序。

另外一种排序方法就是在sort里面重写比较器:

Arrays.sort(peoples,new Comparator() {
@Override
public int compare(People o1,People o2) {
if(o1.weight==o2.weight){
return o2.height-o1.height;
}else{
return o1.weight-o2.weight;
}
}

});

这样的话类就不用实现Comparable<~>接口。