欢迎您访问程序员文章站！本站旨在为大家提供和分享程序员计算机编程知识！
您现在的位置是: 首页

C#实现KNN算法

程序员文章站 2024-03-16 18:15:34
...
以下是KNN算法的C#代码。上一篇博客中用C#创建KD树的程序，其算法仿照了MATLAB中KDTree的实现思路；
这次按照李航老师《统计学习方法》中的思路重新编写了一个C#程序，其中创建KD树时分割维度的选择不是轮询，而是选取数据范围（最大值与最小值之差）最大的维度。
using System;
using System.Collections.Generic;
using System.Linq;


namespace KNNSearch
{
    /// <summary>
    /// KD-tree nearest-neighbour demo over random 3-D points.
    /// The tree is built following Li Hang's "Statistical Learning Methods"; the split axis at each
    /// node is the axis with the largest value range rather than a round-robin choice.
    /// </summary>
    public class Knn
    {
        /// <summary>
        /// Maximum number of points a node may hold and still be made a leaf.
        /// </summary>
        private readonly int leafnum = 1;

        /// <summary>
        /// Collection of node names.
        /// </summary>
        private List<string> _nodeNames = new List<string>
        {
            "A",
            "B",
            "C",
            "D",
            "E",
            "F",
            "G",
            "H",
            "I",
            "J",
            "K",
            "L",
            "M",
            "N",
            "O",
            "P",
            "Q",
            "R",
            "S",
            "T",
            "U",
            "V",
            "W",
            "X",
            "Y",
            "Z"
        };

        /// <summary>
        /// Coordinate accessors indexed by axis number: 0 = X, 1 = Y, 2 = Z.
        /// A static readonly field, so the dictionary is created once rather than on every access.
        /// </summary>
        private static readonly Dictionary<int, Func<Point, double>> Dict = new Dictionary<int, Func<Point, double>>
        {
            { 0, p => p.X },
            { 1, p => p.Y },
            { 2, p => p.Z },
        };

        /// <summary>
        /// Gets or sets the collection of node names.
        /// </summary>
        public List<string> NodeNames { get => _nodeNames; set => _nodeNames = value; }

        /// <summary>
        /// Generates <paramref name="num"/> pseudo-random points with coordinates in [0, 1).
        /// A fixed seed (1) makes every run produce the same points.
        /// </summary>
        /// <param name="num">Number of points to generate.</param>
        /// <returns>The generated points; <see cref="Point.ID"/> is set to the generation index.</returns>
        private List<Point> GeneralRawData(int num)
        {
            var rawData = new List<Point>();
            var r = new Random(1);
            for (var i = 0; i < num; i++)
            {
                rawData.Add(new Point { X = r.NextDouble(), Y = r.NextDouble(), Z = r.NextDouble(), ID = i });
            }
            return rawData;
        }

        /// <summary>
        /// Recursively builds a KD tree from <paramref name="data"/>.
        /// </summary>
        /// <param name="data">Points to place in this subtree.</param>
        /// <returns>The subtree root, or <c>null</c> when <paramref name="data"/> is null or empty.</returns>
        private Node CreateKdTree(List<Point> data)
        {
            if (data == null || data.Count == 0)
            {
                return null;
            }

            // Every node keeps the list of points belonging to its subtree.
            var root = new Node { NodeData = data, Name = "AA" };

            // Few enough points: stop splitting and make this node a leaf (Splitaxis = -1 marks a leaf).
            if (data.Count <= leafnum)
            {
                root.LeftNode = null;
                root.RightNode = null;
                root.Point = data[0];
                root.Splitaxis = -1;
                return root;
            }

            int splitAxis = GetSplitAxis(data);
            Tuple<Point, List<Point>, List<Point>> dataSplit = GetSplitNum(data, splitAxis);
            root.Splitaxis = splitAxis;
            root.Point = dataSplit.Item1;
            root.LeftNode = CreateKdTree(dataSplit.Item2);
            root.RightNode = CreateKdTree(dataSplit.Item3);
            return root;
        }

        /// <summary>
        /// Splits <paramref name="data"/> at its median along <paramref name="splitAxis"/>.
        /// </summary>
        /// <param name="data">Points to split; must be non-empty.</param>
        /// <param name="splitAxis">Axis index (0 = X, 1 = Y, 2 = Z).</param>
        /// <returns>
        /// (median point, points sorted before the median, points sorted after the median).
        /// The median itself belongs to neither half; it becomes the node's splitting point.
        /// </returns>
        private Tuple<Point, List<Point>, List<Point>> GetSplitNum(List<Point> data, int splitAxis)
        {
            Func<Point, double> coord = Dict[splitAxis];
            List<Point> sorted = data.OrderBy(coord).ToList();
            int half = sorted.Count / 2;
            List<Point> leftdata = sorted.GetRange(0, half);
            List<Point> rightdata = sorted.GetRange(half + 1, sorted.Count - half - 1);
            return Tuple.Create(sorted[half], leftdata, rightdata);
        }

        /// <summary>
        /// Picks the axis whose value range (max - min) over <paramref name="data"/> is largest.
        /// Ties go to the lowest axis index. (Variance or round-robin are common alternatives.)
        /// </summary>
        /// <param name="data">Points to inspect; must be non-empty.</param>
        /// <returns>The chosen axis index.</returns>
        private int GetSplitAxis(List<Point> data)
        {
            int bestAxis = 0;
            double bestRange = double.NegativeInfinity;
            for (int axis = 0; axis < Dict.Count; axis++)
            {
                Func<Point, double> coord = Dict[axis];
                double range = data.Max(coord) - data.Min(coord);
                if (range > bestRange)
                {
                    bestRange = range;
                    bestAxis = axis;
                }
            }
            return bestAxis;
        }

        /// <summary>
        /// Nearest-neighbour search in a KD tree built by <see cref="CreateKdTree"/>.
        /// </summary>
        /// <param name="tree">Tree root (may be <c>null</c> for an empty tree).</param>
        /// <param name="target">Query point.</param>
        /// <returns>
        /// The node holding the point nearest to <paramref name="target"/>: an internal node whose
        /// <see cref="Node.Point"/> is nearest, or a leaf whose <see cref="Node.NodeData"/> contains it.
        /// Returns <c>null</c> when <paramref name="tree"/> is <c>null</c>.
        /// </returns>
        private Node KdTreeFindNearest(Node tree, Point target)
        {
            Node best = null;
            double bestDist = double.PositiveInfinity;
            SearchNearest(tree, target, ref best, ref bestDist);
            return best;
        }

        /// <summary>
        /// Recursive step of <see cref="KdTreeFindNearest"/>; updates <paramref name="best"/> and
        /// <paramref name="bestDist"/> whenever a strictly closer candidate is found.
        /// </summary>
        private void SearchNearest(Node node, Point target, ref Node best, ref double bestDist)
        {
            if (node == null)
            {
                return;
            }

            if (node.Splitaxis < 0)
            {
                // Leaf: every point it stores is a candidate.
                double leafDist = NodeDistance(node, target);
                if (leafDist < bestDist)
                {
                    bestDist = leafDist;
                    best = node;
                }
                return;
            }

            // (1) Go down the side of the splitting plane that contains the target first,
            //     using the same "<= goes left" rule as the tree construction.
            Func<Point, double> coord = Dict[node.Splitaxis];
            double diff = coord(target) - coord(node.Point);
            Node nearSide = diff <= 0 ? node.LeftNode : node.RightNode;
            Node farSide = diff <= 0 ? node.RightNode : node.LeftNode;
            SearchNearest(nearSide, target, ref best, ref bestDist);

            // (3a) While backtracking, the node's own splitting point is a candidate.
            double nodeDist = NodeDistance(node, target);
            if (nodeDist < bestDist)
            {
                bestDist = nodeDist;
                best = node;
            }

            // (3b) The other side can only contain a closer point if the sphere of radius bestDist
            //      around the target crosses the splitting plane.
            if (Math.Abs(diff) < bestDist)
            {
                SearchNearest(farSide, target, ref best, ref bestDist);
            }
        }

        /// <summary>
        /// Distance from <paramref name="target"/> to the candidate(s) a node contributes during search:
        /// all stored points for a leaf, the splitting point for an internal node.
        /// </summary>
        private double NodeDistance(Node node, Point target)
        {
            return node.Splitaxis < 0 ? NearestDist(node.NodeData, target) : Distance(node.Point, target);
        }

        /// <summary>
        /// Euclidean distance between two points.
        /// </summary>
        private static double Distance(Point a, Point b)
        {
            double dx = a.X - b.X;
            double dy = a.Y - b.Y;
            double dz = a.Z - b.Z;
            return Math.Sqrt(dx * dx + dy * dy + dz * dz);
        }

        /// <summary>
        /// Smallest Euclidean distance from <paramref name="target"/> to any point in <paramref name="nodeData"/>.
        /// </summary>
        /// <param name="nodeData">Points to scan; must be non-empty.</param>
        /// <param name="target">Query point.</param>
        /// <returns>The minimum distance.</returns>
        /// <exception cref="InvalidOperationException"><paramref name="nodeData"/> is empty.</exception>
        private double NearestDist(List<Point> nodeData, Point target)
        {
            return nodeData.Min(item => Distance(item, target));
        }

        /// <summary>
        /// Writes a separator line followed by each point of <paramref name="data"/> to the console.
        /// </summary>
        private void PrintListData(List<Point> data)
        {
            Console.WriteLine("****************");
            foreach (Point point in data)
            {
                Console.WriteLine(point);
            }
        }

        /// <summary>
        /// Demo: builds a KD tree over 180 generated points, searches for the nearest neighbour of
        /// (0.5, 0.5, 0.5) and prints the KD-tree result next to a brute-force check.
        /// </summary>
        public Knn()
        {
            List<Point> rawData = GeneralRawData(180);
            Node node = CreateKdTree(rawData);
            var target = new Point { X = 0.5, Y = 0.5, Z = 0.5 };
            Node nd = KdTreeFindNearest(node, target);
            // Nearest distance found by the KD-tree search.
            double nearestDistFromKnn = NodeDistance(nd, target);
            Console.WriteLine("通过KNN搜索计算得到的最短距离为{0:F3}", nearestDistFromKnn);
            double nearestDistFromLoop = NearestDist(rawData, target);
            Console.WriteLine("通过KNN遍历计算得到的最短距离为{0:F3}", nearestDistFromLoop);
        }
    }

    /// <summary>
    /// A node of the KD tree.
    /// </summary>
    public class Node
    {
        /// <summary>
        /// Node name.
        /// </summary>
        public string Name;
        /// <summary>
        /// Splitting (threshold) point of the node.
        /// </summary>
        public Point Point;
        /// <summary>
        /// Left child node.
        /// </summary>
        public Node LeftNode;
        /// <summary>
        /// Right child node.
        /// </summary>
        public Node RightNode;
        /// <summary>
        /// Data points contained in this node.
        /// </summary>
        public List<Point> NodeData;
        /// <summary>
        /// Index of the split axis (0 = X, 1 = Y, 2 = Z); the tree builder uses -1 to mark a leaf.
        /// </summary>
        public int Splitaxis;
    }

    /// <summary>
    /// A point in 3-D space.
    /// </summary>
    public class Point
    {
        /// <summary>X coordinate.</summary>
        public double X;
        /// <summary>Y coordinate.</summary>
        public double Y;
        /// <summary>Z coordinate.</summary>
        public double Z;
        /// <summary>Identifier, used for debugging output.</summary>
        public int ID;

        /// <summary>
        /// Formats the point as <c>(X,Y,Z,ID)</c>.
        /// The invariant culture is used so the decimal separator is always '.', which keeps the
        /// comma-separated fields unambiguous in cultures that use ',' as the decimal mark.
        /// </summary>
        public override string ToString() => FormattableString.Invariant($"({X},{Y},{Z},{ID})");
    }
}