C#实现KNN算法
程序员文章站
2024-03-16 18:15:34
...
本文给出KNN算法的C#实现。上一篇博客中用C#创建KD树的程序,模仿的是MATLAB中KDTree的实现思路。
这次按照李航老师《统计学习方法》中的思路写一个C#程序:创建KD树时,分割维度不是按各维轮流选取,而是选取数据取值范围(最大值减最小值)最大的那一维。
using System;
using System.Collections.Generic;
using System.Linq;
namespace KNNSearch
{
/// <summary>
/// Builds a 3-D KD tree over generated sample points and runs a nearest-neighbour query against it.
/// </summary>
public class Knn
{
/// <summary>
/// Maximum number of points a node may hold and still be treated as a leaf
/// (compared against <c>data.Count</c> when the tree is built).
/// </summary>
private int leafnum = 1;

/// <summary>
/// Candidate node names "A" through "Z", exposed through <c>NodeNames</c>.
/// </summary>
private List<string> _nodeNames = new List<string>
{
    "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
    "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"
};
/// <summary>
/// Generates <paramref name="num"/> sample points whose X, Y and Z coordinates are taken,
/// in that order, from <see cref="Random.NextDouble"/>; IDs are assigned sequentially from 0.
/// </summary>
/// <remarks>
/// The generator is seeded with the constant 1, so every call produces the same sequence.
/// </remarks>
/// <param name="num">Number of points to generate; a non-positive value yields an empty list.</param>
/// <returns>The generated points in creation order.</returns>
private List<Point> GeneralRawData(int num)
{
    var rawData = new List<Point>();
    var random = new Random(1);
    for (var i = 0; i < num; i++)
    {
        rawData.Add(new Point
        {
            X = random.NextDouble(),
            Y = random.NextDouble(),
            Z = random.NextDouble(),
            ID = i
        });
    }

    return rawData;
}
/// <summary>
/// Recursively builds a KD tree over <paramref name="data"/>.
/// </summary>
/// <remarks>
/// A node holding at most <c>leafnum</c> points becomes a leaf (split axis -1, no children,
/// <c>Point</c> set to the first element). Otherwise the data is split at the median along the
/// axis chosen by <c>GetSplitAxis</c>; the median becomes the node's <c>Point</c> and the two
/// halves become the left and right subtrees.
/// </remarks>
/// <param name="data">Points covered by this (sub)tree.</param>
/// <returns>The root of the (sub)tree, or <c>null</c> when <paramref name="data"/> is null or empty.</returns>
private Node CreateKdTree(List<Point> data)
{
    // An empty partition has no node; the parent keeps a null child.
    if (data == null || data.Count == 0)
    {
        return null;
    }

    // Every node keeps the full set of points of its subtree and the fixed name "AA".
    Node root = new Node { NodeData = data, Name = "AA" };

    if (data.Count <= leafnum)
    {
        // Leaf: children stay null, -1 marks "no split axis".
        root.Point = data[0];
        root.Splitaxis = -1;
        return root;
    }

    int splitAxis = GetSplitAxis(data);
    Tuple<Point, List<Point>, List<Point>> dataSplit = GetSplitNum(data, splitAxis);

    root.Splitaxis = splitAxis;
    root.Point = dataSplit.Item1;
    root.LeftNode = CreateKdTree(dataSplit.Item2);
    root.RightNode = CreateKdTree(dataSplit.Item3);
    return root;
}
/// <summary>
/// Sorts <paramref name="data"/> along <paramref name="splitAxis"/> and splits it at the median.
/// </summary>
/// <param name="data">Points to split; must contain at least one element.</param>
/// <param name="splitAxis">Axis index understood by <c>Dict</c> (0, 1 or 2).</param>
/// <returns>
/// A tuple of the median point (index <c>Count / 2</c> of the sorted data), the points sorted
/// before it, and the points sorted after it. The median is in neither list.
/// </returns>
/// <exception cref="ArgumentException"><paramref name="data"/> is empty.</exception>
private Tuple<Point, List<Point>, List<Point>> GetSplitNum(List<Point> data, int splitAxis)
{
    if (data.Count == 0)
    {
        throw new ArgumentException("Cannot split an empty point set.", nameof(data));
    }

    // Look the accessor up once instead of once per element inside the sort key selector.
    Func<Point, double> coordinate = Dict[splitAxis];
    List<Point> sorted = data.OrderBy(coordinate).ToList();

    int half = sorted.Count / 2;
    List<Point> leftData = sorted.GetRange(0, half);
    List<Point> rightData = sorted.GetRange(half + 1, sorted.Count - half - 1);

    return Tuple.Create(sorted[half], leftData, rightData);
}
/// <summary>
/// Chooses the split axis as the one on which the data has the largest extent (max - min).
/// </summary>
/// <remarks>
/// When several axes share the largest extent, the lowest axis index is returned.
/// Other strategies (variance, round-robin) are possible but not used here.
/// </remarks>
/// <param name="data">Points to examine; must be non-empty (Min/Max throw on an empty sequence).</param>
/// <returns>Axis index 0, 1 or 2.</returns>
private int GetSplitAxis(List<Point> data)
{
    const int axisCount = 3;

    int bestAxis = 0;
    double bestRange = double.NegativeInfinity;
    for (int axis = 0; axis < axisCount; axis++)
    {
        Func<Point, double> coordinate = Dict[axis];
        double range = data.Max(coordinate) - data.Min(coordinate);

        // Strict comparison keeps the first (lowest-index) axis on ties.
        if (range > bestRange)
        {
            bestRange = range;
            bestAxis = axis;
        }
    }

    return bestAxis;
}
/// <summary>
/// Finds the tree node that holds the point closest (Euclidean distance) to <paramref name="target"/>.
/// </summary>
/// <remarks>
/// The search first descends on the side of each splitting plane that contains the target, and
/// visits the opposite side only when the plane is closer to the target than the best distance
/// found so far. For an inner node the candidate is its <c>Point</c>; for a leaf every point in
/// <c>NodeData</c> is a candidate.
/// </remarks>
/// <param name="tree">Root of the KD tree; may be <c>null</c>.</param>
/// <param name="target">Query point.</param>
/// <returns>The node holding the nearest point, or <c>null</c> when the tree is empty.</returns>
private Node KdTreeFindNearest(Node tree, Point target)
{
    Node best = null;
    double bestDist = double.PositiveInfinity;
    SearchSubtree(tree, target, ref best, ref bestDist);
    return best;
}

/// <summary>
/// Visits <paramref name="node"/> and the parts of its subtree that can still contain a point
/// closer than <paramref name="bestDist"/>, updating the best node and distance in place.
/// </summary>
private void SearchSubtree(Node node, Point target, ref Node best, ref double bestDist)
{
    // Missing children are stored as null; there is nothing to search.
    if (node == null)
    {
        return;
    }

    if (node.Splitaxis < 0)
    {
        // Leaf: check every point stored in it.
        foreach (Point candidate in node.NodeData)
        {
            double candidateDist = PointDistance(candidate, target);
            if (best == null || candidateDist < bestDist)
            {
                best = node;
                bestDist = candidateDist;
            }
        }

        return;
    }

    // Inner node: its own split point is a candidate.
    double ownDist = PointDistance(node.Point, target);
    if (best == null || ownDist < bestDist)
    {
        best = node;
        bestDist = ownDist;
    }

    Func<Point, double> coordinate = Dict[node.Splitaxis];
    double offset = coordinate(target) - coordinate(node.Point);
    bool targetOnLeft = offset <= 0;
    Node nearSide = targetOnLeft ? node.LeftNode : node.RightNode;
    Node farSide = targetOnLeft ? node.RightNode : node.LeftNode;

    SearchSubtree(nearSide, target, ref best, ref bestDist);

    // The far side can only hold a closer point if the splitting plane itself is
    // nearer to the target than the best distance found so far.
    if (Math.Abs(offset) < bestDist)
    {
        SearchSubtree(farSide, target, ref best, ref bestDist);
    }
}

/// <summary>
/// Euclidean distance between two points.
/// </summary>
private static double PointDistance(Point a, Point b)
{
    double dx = a.X - b.X;
    double dy = a.Y - b.Y;
    double dz = a.Z - b.Z;
    return Math.Sqrt(dx * dx + dy * dy + dz * dz);
}
/// <summary>
/// Maps an axis index (0 = X, 1 = Y, 2 = Z) to a function that reads that coordinate of a point.
/// </summary>
/// <remarks>
/// Stored once in a read-only static field; an expression-bodied property here would build a
/// new dictionary on every access.
/// </remarks>
private static readonly Dictionary<int, Func<Point, double>> Dict = new Dictionary<int, Func<Point, double>>
{
    { 0, p => p.X },
    { 1, p => p.Y },
    { 2, p => p.Z },
};
/// <summary>
/// Gets or sets the list of candidate node names backing <c>_nodeNames</c>.
/// </summary>
public List<string> NodeNames { get => _nodeNames; set => _nodeNames = value; }
/// <summary>
/// Computes the smallest Euclidean distance from <paramref name="target"/> to any point in
/// <paramref name="nodeData"/>.
/// </summary>
/// <param name="nodeData">Candidate points; must be non-empty (Min throws on an empty sequence).</param>
/// <param name="target">Query point.</param>
/// <returns>The minimum distance.</returns>
private double NearestDist(List<Point> nodeData, Point target)
{
    return nodeData.Min(item => Math.Sqrt(Math.Pow(item.X - target.X, 2) +
                                          Math.Pow(item.Y - target.Y, 2) +
                                          Math.Pow(item.Z - target.Z, 2)));
}
/// <summary>
/// Writes a separator line followed by one line per point (via <c>Point.ToString</c>) to the console.
/// </summary>
/// <param name="data">Points to print.</param>
private void PrintListData(List<Point> data)
{
    Console.WriteLine("****************");
    foreach (Point point in data)
    {
        Console.WriteLine(point);
    }
}
/// <summary>
/// Runs the demo: generates 180 sample points, builds the KD tree, queries the nearest
/// neighbour of (0.5, 0.5, 0.5) and prints the distance found through the tree next to the
/// distance found by scanning all points.
/// </summary>
public Knn()
{
    List<Point> rawData = GeneralRawData(180);
    Node node = CreateKdTree(rawData);
    Point target = new Point { X = 0.5, Y = 0.5, Z = 0.5 };

    // Distance obtained through the KD-tree search.
    Node nd = KdTreeFindNearest(node, target);
    double nearestDistFromKnn = NearestDist(nd.NodeData, target);
    Console.WriteLine("通过KNN搜索计算得到的最短距离为{0:F3}", nearestDistFromKnn);

    // Reference distance obtained by scanning every generated point.
    double nearestDistFromLoop = NearestDist(rawData, target);
    Console.WriteLine("通过KNN遍历计算得到的最短距离为{0:F3}", nearestDistFromLoop);
}
}
/// <summary>
/// A node of the KD tree.
/// </summary>
public class Node
{
    /// <summary>
    /// Node name.
    /// </summary>
    public string Name;

    /// <summary>
    /// The point stored at this node; for a split node it is the point the data was split at.
    /// </summary>
    public Point Point;

    /// <summary>
    /// Left child, or <c>null</c> if there is none.
    /// </summary>
    public Node LeftNode;

    /// <summary>
    /// Right child, or <c>null</c> if there is none.
    /// </summary>
    public Node RightNode;

    /// <summary>
    /// The points covered by this node.
    /// </summary>
    public List<Point> NodeData;

    /// <summary>
    /// Index of the axis this node splits on.
    /// </summary>
    public int Splitaxis;
}
/// <summary>
/// A point in 3-D space with an identifier.
/// </summary>
public class Point
{
    /// <summary>X coordinate.</summary>
    public double X;

    /// <summary>Y coordinate.</summary>
    public double Y;

    /// <summary>Z coordinate.</summary>
    public double Z;

    /// <summary>Identifier, used for debugging.</summary>
    public int ID;

    /// <summary>
    /// Formats the point as "(X,Y,Z,ID)".
    /// </summary>
    public override string ToString() => string.Format("({0},{1},{2},{3})", X, Y, Z, ID);
}
}