HDU - 6447 YJJ's Salesman (线段树优化dp)
程序员文章站
2022-06-09 17:42:15
...
http://acm.hdu.edu.cn/showproblem.php?pid=6447
题意:
给你几个坐标,每个坐标都有其财富值,你每次从(0,0)开始走,只能向右,或向下或者右下走,有财富的坐标只能从左上方进入,问最多的财富为多少
思路:
第一眼看到这道题感觉是dp,但是dp的话,数组太大,开不了。于是就离散化,用线段树优化dp
对于一个点(x,y),它的最大值是(0->(x-1),0->(y-1))这个子矩阵的最大值,而用矩阵的话,数据太大了,于是我们可以用数组来记录前面这些列的最大值,因为是区间查找,所以用线段树最好,
举个例子:
对于样例:
1
3
1 1 1
1 2 2
3 3 1
一个dp[]数组记录每一列的当前位置可以到达的最大值
先一列一列的来
第0列最大值为0,dp[]全为0
第1列易得dp[1]=1,dp[2]=2,其他全为0,
第3列同第0列,这是dp[1]=1,dp[2]=2
而到了第4列dp[3] = max(dp[1], dp[2]) + mp[3][3],最大值为dp[3]=3
最后的dp为:dp[0]=0,dp[1]=1,dp[2]=2,dp[3]=3
就是线段数记录的是在经过前面n列后每个位置可以到达的最大值,最后在遍历一下dp,找最大值即可
AC代码(从0开始)
#include <iostream>
#include <string.h>
#include <string>
#include <math.h>
#include <stdlib.h>
#include <vector>
#include <set>
#include <map>
#include <queue>
#include <stack>
#include <bitset>
#include <algorithm>
#include <stdio.h>
#include <deque>
#define M 1000000007
using namespace std;
const int maxn = 2e5 + 5;
struct Point{
int x, y, w;
friend bool operator < (Point a, Point b) {//排序以x(列)从小到大排序,x相同时按照y(行)从大到小排序,参照01背包
if(a.x == b.x) return a.y > b.y;
return a.x < b.x;
}
}point[maxn];
int pointx[maxn], pointy[maxn], dp[maxn];
int p, L, R, w, ans;
struct Tree{
int l, r, w;
}tree[maxn << 2];
void build(int k, int ll, int rr) {//建树
tree[k].l = ll; tree[k].r = rr;
if(tree[k].l == tree[k].r) {
tree[k].w = 0;
return ;
}
int mm = (ll + rr) / 2;
build(k * 2, ll, mm);
build(k * 2 + 1, mm + 1, rr);
tree[k].w = max(tree[k * 2].w, tree[k * 2 + 1].w);
}
void Point_Change(int k) {//单点修改
if(tree[k].l == tree[k].r) {
tree[k].w = w;
return ;
}
int mm = (tree[k].l + tree[k].r) / 2;
if(p <= mm) Point_Change(k * 2);
else Point_Change(k * 2 + 1);
tree[k].w = max(tree[k * 2].w, tree[k * 2 + 1].w);
}
void Interval_Ask(int k) {//区间查询
if(tree[k].l >= L && tree[k].r <= R) {
ans = max(ans, tree[k].w);
return ;
}
int mm = (tree[k].l + tree[k].r) / 2;
if(L <= mm) Interval_Ask(k * 2);
if(R > mm) Interval_Ask(k * 2 + 1);
}
int main() {
int T;
scanf("%d", &T);
while(T --) {
int n;
scanf("%d", &n);
for (int i = 0; i < n; i ++) {
scanf("%d%d%d", &point[i].x, &point[i].y, &point[i].w);
pointx[i] = point[i].x; pointy[i] = point[i].y;
}
//---离散化开始----
pointx[n] = 0; pointy[n] = 0;//加入(0,0)
sort(pointx, pointx + n + 1); sort(pointy, pointy + n + 1);
int px = unique(pointx, pointx + n + 1) - pointx;//x不重复有多少个
int py = unique(pointy, pointy + n + 1) - pointy;//y不重复有多少个
for(int i = 0; i < n; i ++) {
point[i].x = lower_bound(pointx, pointx + px, point[i].x) - pointx;//给点重新编号,
point[i].y = lower_bound(pointy, pointy + py, point[i].y) - pointy;//从0开始
}
// ----离散化结束-----
sort(point, point + n);
memset(dp, 0, sizeof(dp));
build(1, 0, py);
for (int i = 0; i < n; i ++) {
int kk = point[i].x, j;
for (j = i; j < n; j ++) {
if(point[j].x != kk) break;
L = 0; R = point[j].y - 1;
ans = 0;
Interval_Ask(1);
int tmp = ans + point[j].w;
if(tmp > dp[point[j].y]) {
dp[point[j].y] = tmp;
p = point[j].y;
w = dp[point[j].y];
Point_Change(1);
}
}
i = j - 1;
}
int res = 0;
for (int i = 0; i <= py; i ++)
res = max(res, dp[i]);
printf("%d\n", res);
}
return 0;
}
从0开始和从1开始基本一样,下面贴上从1开始的代码
AC代码(从1开始)
#include <iostream>
#include <string.h>
#include <string>
#include <math.h>
#include <stdlib.h>
#include <vector>
#include <set>
#include <map>
#include <queue>
#include <stack>
#include <bitset>
#include <algorithm>
#include <stdio.h>
#include <deque>
using namespace std;
const int maxn = 1e5 + 5;
struct Point{
int x, y, w;
friend bool operator < (Point a, Point b) {
if(a.x == b.x) return a.y > b.y;
return a.x < b.x;
}
}point[maxn];
int pointx[maxn], pointy[maxn], dp[maxn];
int p, L, R, w, ans;
struct Tree{
int l, r, w;
}tree[maxn << 2];
void build(int k, int ll, int rr) {//建树
tree[k].l = ll; tree[k].r = rr;
if(tree[k].l == tree[k].r) {
tree[k].w = 0;
return ;
}
int mm = (ll + rr) / 2;
build(k * 2, ll, mm);
build(k * 2 + 1, mm + 1, rr);
tree[k].w = max(tree[k * 2].w, tree[k * 2 + 1].w);
}
void Point_Change(int k) {//单点修改
if(tree[k].l == tree[k].r) {
tree[k].w = w;
return ;
}
int mm = (tree[k].l + tree[k].r) / 2;
if(p <= mm) Point_Change(k * 2);
else Point_Change(k * 2 + 1);
tree[k].w = max(tree[k * 2].w, tree[k * 2 + 1].w);
}
void Interval_Ask(int k) {//区间查询
if(tree[k].l >= L && tree[k].r <= R) {
ans = max(ans, tree[k].w);
return ;
}
int mm = (tree[k].l + tree[k].r) / 2;
if(L <= mm) Interval_Ask(k * 2);
if(R > mm) Interval_Ask(k * 2 + 1);
}
int main() {
int T;
scanf("%d", &T);
while(T --) {
int n;
scanf("%d", &n);
memset(dp, 0, sizeof(dp));
for (int i = 0; i < n; i ++) {
scanf("%d%d%d", &point[i].x, &point[i].y, &point[i].w);
pointx[i] = point[i].x; pointy[i] = point[i].y;
}
pointx[n] = 0; pointy[n] = 0;//加入(0,0)
sort(pointx, pointx + n + 1); sort(pointy, pointy + n + 1);
int px = unique(pointx, pointx + n + 1) - pointx;//x不重复有多少个
int py = unique(pointy, pointy + n + 1) - pointy;//y不重复有多少个
for(int i = 0; i < n; i ++) {
point[i].x = lower_bound(pointx, pointx + px, point[i].x) - pointx + 1;//给点重新编号,
point[i].y = lower_bound(pointy, pointy + py, point[i].y) - pointy + 1;//从1开始
}
// 离散化
sort(point, point + n);
build(1, 1, py + 1);
for (int i = 0; i < n; i ++) {
int kk = point[i].x, j;
for (j = i; j < n; j ++) {
if(point[j].x != kk) break;
L = 1, R = point[j].y - 1;
ans = 0;
Interval_Ask(1);
int tmp = ans + point[j].w;
if(tmp > dp[point[j].y]) {
dp[point[j].y] = tmp;
p = point[j].y;
w = dp[point[j].y];
Point_Change(1);
}
}
i = j - 1;
}
int res = 0;
for (int i = 0; i <= py + 1; i ++)
res = max(res, dp[i]);
printf("%d\n", res);
}
return 0;
}