欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

kdTree算法实现(C++代码实现)

程序员文章站 2022-07-14 14:39:52
...

kdTree算法实现(C++代码实现)

VS中需要安装Eigen包,可自行百度下载,添加至属性——包含项目即可。

kdTree.h文件

#ifndef KD_TREE_H
#define KD_TREE_H

#define lson (rt << 1)//左节点
#define rson (rt << 1 | 1)//右节点

#include
#include
#include <Eigen/Dense>
#include

using std::vector;
using Eigen::MatrixXf;

const int N = 50005;
const int k = 2; //2D-tree

struct Node {
int feature[2];//feature[0] = x, feature[1] = y
static int idx;
Node(int x0, int y0) {
feature[0] = x0;
feature[1] = y0;
}
bool operator < (const Node &u) const {
return feature[idx] < u.feature[idx];
}
//TOOD =hao
Node() {
feature[0] = 0;
feature[0] = 0;
}
};

class kdTree {
public:
kdTree();
~kdTree();
void clean();
void read_in(float* ary_x, float* ary_y, int len);
void read_in(vector points_array);
void build(int l, int r, int rt, int dept);
int find_nearest_point(float x, float y, Node &res, int& dist, int boundary);
int find_nearest_k_points(float x, float y, int k, vector &res, vector &dist, int boundary);
int distance(const Node& x, const Node& y);
private:
void query(const Node& p, Node& res, int& dist, int rt, int dept);
void query_k(const Node& p, int k, vector &res, vector &dist, int rt, int dept);
vector _data;//用vector模拟数组
vector _flag;//判断是否存在 1, -1, 0
int _idx;
vector _find_nth;
};

#endif

kdTree.cpp文件

#include "kdTree.h"

int Node::idx = 0;

kdTree::kdTree() 
{
	_data.reserve(N * 4);
	_flag.reserve(N * 4);//TODO init
}

kdTree::~kdTree() 
{
}

void kdTree::read_in(float* ary_x, float* ary_y, int len) 
{
	_find_nth.reserve(N * 4);
	for (int i = 0; i < len; i++) {
		_find_nth.push_back(Node(ary_x[i], ary_y[i]));
	}
	for (int i = 0; i < N * 4; i++) {
		Node tmp;
		_data.push_back(tmp);
		_flag.push_back(0);
	}
	build(0, len - 1, 1, 0);
}

void kdTree::read_in(vector<MatrixXf> points_array)
{
	int len = points_array.size();
	_find_nth.reserve(N * 4);
	for (int i = 0; i < len; i++) {
		_find_nth.push_back(Node(int(points_array[i](0, 0)), int(points_array[i](1, 0))));
	}
	for (int i = 0; i < N * 4; i++) {
		Node tmp;
		_data.push_back(tmp);
		_flag.push_back(0);
	}
	build(0, len - 1, 1, 0);
}

void kdTree::clean() 
{
	_find_nth.clear();
	_data.clear();
	_flag.clear();
}

//建立kd-tree
void kdTree::build(int l, int r, int rt, int dept) 
{
	if (l > r) return;
	_flag[rt] = 1;                  //表示标号为rt的节点存在
	_flag[lson] = _flag[rson] = -1; //当前节点的孩子暂时标记不存在 
	int mid = (l + r + 1) >> 1;
	Node::idx = dept % k;           //按照编号为idx的属性进行划分
	std::nth_element(_find_nth.begin() + l, _find_nth.begin() + mid, _find_nth.begin() + r + 1);
	_data[rt] = _find_nth[mid];
	build(l, mid - 1, lson, dept + 1); //递归左子树
	build(mid + 1, r, rson, dept + 1);
}

int kdTree::find_nearest_point(float x, float y, Node &res, int& dist, int boundary) 
{
	Node p(x, y);
	dist = boundary * boundary;
	query(p, res, dist, 1, 0);
	return 1;
}

int kdTree::find_nearest_k_points(float x, float y, int k, vector<Node>& res, vector<int>& dist, int boundary)
{
	vector<int> idx;
	idx.clear();
	res.clear();
	dist.clear();

	Node p(x, y);
	for (int i = 0; i < k; i++) {
		idx.push_back(-1);
		dist.push_back(boundary * boundary);
		query_k(p, i, idx, dist, 1, 0);
		if (idx[i] < 0) break;
		_flag[idx[i]] = 0;
		res.push_back(_data[idx[i]]);
	}

	int kfind = res.size();

	for (int i = 0; i < kfind; i++) {
		_flag[idx[i]] = 1;
	}

	return kfind;
}

//查找kd-tree距离p最近的点
void kdTree::query(const Node& p, Node& res, int& dist, int rt, int dept) 
{
	if (_flag[rt] == -1) return; //不存在的节点不遍历
	int tmp_dist = distance(_data[rt], p);
	bool fg = false; //用于标记是否需要遍历右子树
	int dim = dept % k; //和建树一样, 保证相同节点的dim值不变
	int x = lson;
	int y = rson;

	if (p.feature[dim] >= _data[rt].feature[dim]) std::swap(x, y);  //数据p的第dim个特征值大于等于当前的数据,则需要进入右子树
	if (x < _flag.size() && _flag[x] != -1) query(p, res, dist, x, dept + 1); //节点x存在, 则进入子树继续遍历
	if (tmp_dist < dist) { //如果找到更小的距离, 则替换目前的结果dist
		res = _data[rt];
		dist = tmp_dist;
	}
	tmp_dist = (p.feature[dim] - _data[rt].feature[dim]) * (p.feature[dim] - _data[rt].feature[dim]);
	if (tmp_dist < dist) fg = true; //还需要继续回溯
	if (y < _flag.size() && _flag[y] != -1 && fg) query(p, res, dist, y, dept + 1);
}

void kdTree::query_k(const Node & p, int i, vector<int>& res, vector<int>& dist, int rt, int dept)
{
	if (_flag[rt] == -1) return; //不存在的节点不遍历
	int tmp_dist = distance(_data[rt], p);
	bool fg = false; //用于标记是否需要遍历右子树
	int dim = dept % k; //和建树一样, 保证相同节点的dim值不变
	int x = lson;
	int y = rson;

	if (p.feature[dim] >= _data[rt].feature[dim]) std::swap(x, y);  //数据p的第dim个特征值大于等于当前的数据,则需要进入右子树
	if (x<_flag.size() && _flag[x]==1) query_k(p, i, res, dist, x, dept + 1); //节点x存在, 则进入子树继续遍历
	if (tmp_dist < dist[i] && _flag[rt] == 1) { //如果找到更小的距离, 则替换目前的结果dist
		res[i] = rt;
		dist[i] = tmp_dist;
	}
	tmp_dist = (p.feature[dim] - _data[rt].feature[dim]) * (p.feature[dim] - _data[rt].feature[dim]);
	if (tmp_dist < dist[i]) fg = true; //还需要继续回溯
	if (y < _flag.size() && _flag[y] == 1 && fg) query_k(p, i, res, dist, y, dept + 1);
}

//计算两点间的距离的平方
int kdTree::distance(const Node& x, const Node& y) 
{
	int res = 0;
	for (int i = 0; i < k; i++) 
	{
		res += (x.feature[i] - y.feature[i]) * (x.feature[i] - y.feature[i]);
	}
	return res;
}

kdTree可实现实现IDW(反距离插值算法),有时间再写上IDW算法的实现。