欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

HDU - 6447 YJJ's Salesman (线段树优化dp)

程序员文章站 2022-06-09 17:42:15
...

http://acm.hdu.edu.cn/showproblem.php?pid=6447

题意:

给你几个坐标,每个坐标都有其财富值,你每次从(0,0)开始走,只能向右,或向下或者右下走,有财富的坐标只能从左上方进入,问最多的财富为多少

思路:

第一眼看到这道题感觉是dp,但是dp的话,数组太大,开不了。于是就离散化,用线段树优化dp
对于一个点(x,y),它的最大值是(0->(x-1),0->(y-1))这个子矩阵的最大值,而用矩阵的话,数据太大了,于是我们可以用数组来记录前面这些列的最大值,因为是区间查找,所以用线段树最好,
举个例子:
对于样例:
1
3
1 1 1
1 2 2
3 3 1
一个dp[]数组记录每一列的当前位置可以到达的最大值
HDU - 6447 YJJ's Salesman (线段树优化dp)
先一列一列的来
第0列最大值为0,dp[]全为0
第1列易得dp[1]=1,dp[2]=2,其他全为0,
第3列同第0列,这是dp[1]=1,dp[2]=2
而到了第4列dp[3] = max(dp[1], dp[2]) + mp[3][3],最大值为dp[3]=3
最后的dp为:dp[0]=0,dp[1]=1,dp[2]=2,dp[3]=3
就是线段数记录的是在经过前面n列后每个位置可以到达的最大值,最后在遍历一下dp,找最大值即可

AC代码(从0开始)

#include <iostream>
#include <string.h>
#include <string>
#include <math.h>
#include <stdlib.h>
#include <vector>
#include <set>
#include <map>
#include <queue>
#include <stack>
#include <bitset>
#include <algorithm>
#include <stdio.h>
#include <deque>

#define M  1000000007
using namespace std;
const int maxn = 2e5 + 5;

struct Point{
    int x, y, w;
    friend bool operator < (Point a, Point b) {//排序以x(列)从小到大排序,x相同时按照y(行)从大到小排序,参照01背包
        if(a.x == b.x) return a.y > b.y;
        return a.x < b.x;
    }
}point[maxn];

int pointx[maxn], pointy[maxn], dp[maxn];
int p, L, R, w, ans;

struct Tree{
    int l, r, w;
}tree[maxn << 2];

void build(int k, int ll, int rr) {//建树
    tree[k].l = ll; tree[k].r = rr;
    if(tree[k].l == tree[k].r) {
        tree[k].w = 0;
        return ;
    }
    int mm = (ll + rr) / 2;
    build(k * 2, ll, mm);
    build(k * 2 + 1, mm + 1, rr);
    tree[k].w = max(tree[k * 2].w, tree[k * 2 + 1].w);
}

void Point_Change(int k) {//单点修改
    if(tree[k].l == tree[k].r) {
        tree[k].w = w;
        return ;
    }
    int mm = (tree[k].l + tree[k].r) / 2;
    if(p <= mm) Point_Change(k * 2);
    else Point_Change(k * 2 + 1);
    tree[k].w = max(tree[k * 2].w, tree[k * 2 + 1].w);
}

void Interval_Ask(int k) {//区间查询
    if(tree[k].l >= L && tree[k].r <= R) {
        ans = max(ans, tree[k].w);
        return ;
    }
    int mm = (tree[k].l + tree[k].r) / 2;
    if(L <= mm) Interval_Ask(k * 2);
    if(R > mm) Interval_Ask(k * 2 + 1);
}


int main() {
    int T;
    scanf("%d", &T);
    while(T --) {
        int n;
        scanf("%d", &n);

        for (int i = 0; i < n; i ++) {
            scanf("%d%d%d", &point[i].x, &point[i].y, &point[i].w);
            pointx[i] = point[i].x; pointy[i] = point[i].y;
        }
        //---离散化开始----
        pointx[n] = 0; pointy[n] = 0;//加入(0,0)
        sort(pointx, pointx + n + 1); sort(pointy, pointy + n + 1);
        int px = unique(pointx, pointx + n + 1) - pointx;//x不重复有多少个
        int py = unique(pointy, pointy + n + 1) - pointy;//y不重复有多少个
        for(int i = 0; i < n; i ++) {
            point[i].x = lower_bound(pointx, pointx + px, point[i].x) - pointx;//给点重新编号,
            point[i].y = lower_bound(pointy, pointy + py, point[i].y) - pointy;//从0开始
        }
        // ----离散化结束-----
        sort(point, point + n);
        memset(dp, 0, sizeof(dp));
        build(1, 0, py);
        for (int i = 0; i < n; i ++) {
            int kk = point[i].x, j;
            for (j = i; j < n; j ++) {
                if(point[j].x != kk) break;
                L = 0; R = point[j].y - 1;
                ans = 0;
                Interval_Ask(1);
                int tmp = ans + point[j].w;
                if(tmp > dp[point[j].y]) {
                    dp[point[j].y] = tmp;
                    p = point[j].y;
                    w = dp[point[j].y];
                    Point_Change(1);
                }
            }
            i = j - 1;
        }
        int res = 0;
        for (int i = 0; i <= py; i ++)
            res = max(res, dp[i]);
        printf("%d\n", res);
    }
    return 0;
}

从0开始和从1开始基本一样,下面贴上从1开始的代码

AC代码(从1开始)

#include <iostream>
#include <string.h>
#include <string>
#include <math.h>
#include <stdlib.h>
#include <vector>
#include <set>
#include <map>
#include <queue>
#include <stack>
#include <bitset>
#include <algorithm>
#include <stdio.h>
#include <deque>
using namespace std;
const int maxn = 1e5 + 5;

struct Point{
    int x, y, w;
    friend bool operator < (Point a, Point b) {
        if(a.x == b.x) return a.y > b.y;
        return a.x < b.x;
    }
}point[maxn];

int pointx[maxn], pointy[maxn], dp[maxn];
int p, L, R, w, ans;

struct Tree{
    int l, r, w;
}tree[maxn << 2];

void build(int k, int ll, int rr) {//建树
    tree[k].l = ll; tree[k].r = rr;
    if(tree[k].l == tree[k].r) {
        tree[k].w = 0;
        return ;
    }
    int mm = (ll + rr) / 2;
    build(k * 2, ll, mm);
    build(k * 2 + 1, mm + 1, rr);
    tree[k].w = max(tree[k * 2].w, tree[k * 2 + 1].w);
}

void Point_Change(int k) {//单点修改
    if(tree[k].l == tree[k].r) {
        tree[k].w = w;
        return ;
    }
    int mm = (tree[k].l + tree[k].r) / 2;
    if(p <= mm) Point_Change(k * 2);
    else Point_Change(k * 2 + 1);
    tree[k].w = max(tree[k * 2].w, tree[k * 2 + 1].w);
}

void Interval_Ask(int k) {//区间查询
    if(tree[k].l >= L && tree[k].r <= R) {
        ans = max(ans, tree[k].w);
        return ;
    }
    int mm = (tree[k].l + tree[k].r) / 2;
    if(L <= mm) Interval_Ask(k * 2);
    if(R > mm) Interval_Ask(k * 2 + 1);
}


int main() {
    int T;
    scanf("%d", &T);
    while(T --) {
        int n;
        scanf("%d", &n);
        memset(dp, 0, sizeof(dp));
        for (int i = 0; i < n; i ++) {
            scanf("%d%d%d", &point[i].x, &point[i].y, &point[i].w);
            pointx[i] = point[i].x; pointy[i] = point[i].y;
        }
        pointx[n] = 0; pointy[n] = 0;//加入(0,0)
        sort(pointx, pointx + n + 1); sort(pointy, pointy + n + 1);
        int px = unique(pointx, pointx + n + 1) - pointx;//x不重复有多少个
        int py = unique(pointy, pointy + n + 1) - pointy;//y不重复有多少个
        for(int i = 0; i < n; i ++) {
            point[i].x = lower_bound(pointx, pointx + px, point[i].x) - pointx + 1;//给点重新编号,
            point[i].y = lower_bound(pointy, pointy + py, point[i].y) - pointy + 1;//从1开始
        }
        // 离散化
        sort(point, point + n);
        build(1, 1, py + 1);
        for (int i = 0; i < n; i ++) {
            int kk = point[i].x, j;
            for (j = i; j < n; j ++) {
                if(point[j].x != kk) break;
                L = 1, R = point[j].y - 1;
                ans = 0;
                Interval_Ask(1);
                int tmp = ans + point[j].w;
                if(tmp > dp[point[j].y]) {
                    dp[point[j].y] = tmp;
                    p = point[j].y;
                    w = dp[point[j].y];
                    Point_Change(1);
                }
            }
            i = j - 1;
        }
        int res = 0;
        for (int i = 0; i <= py + 1; i ++)
            res = max(res, dp[i]);
        printf("%d\n", res);
    }
    return 0;
}