[BZOJ 1010][HNOI2008]玩具装箱toy
很久之前是学过并写过斜率优化的……但是很快就忘了。现在感觉自己理解了,感觉是真的懂了……抽空写篇文章解释一下吧……
先单独说这一个题。将DP方程完全展开,并且设\(P_i = S_i + i\),\(c = L + 1\),可得:
\[f_i = c^2 + P_i^2 - 2P_i c + max(P_j^2 + 2P_j c + f_j - 2P_i P_j)\]
然后\(c^2 + P_i^2 - 2P_i c\)这部分是常数项不需要管了,我们就想想max里面那些(姑且设之为\(d_i\))咋整好了。
设\(d_i = P_j^2 + 2P_j c + f_j - 2P_i P_j\),稍作移项,得:
\[2P_i P_j + d_i = P_j^2 + 2P_j c + f_j\]
于是乎,\(d_i\)可以看做斜率为\(2P_i\)的直线过点\((P_j, P_j^2 + 2P_j c + f_j)\)得到的截距。而那些点我们之前都知道了,问题就变成了已知斜率,求过某点集中的点的最大截距。
想象一个固定斜率的直线从下往上扫,那么碰到的第一个点就是最优解。首先这个点一定在下凸壳上,其次下凸壳上这点两侧的线段的斜率肯定一个比\(2P_i\)大另一个比它小。并且最好的一点是这个斜率还是单调的,那么分界点一定是单调递增的。
代码:
/************************************************************** Problem: 1010 User: danihao123 Language: C++ Result: Accepted Time:132 ms Memory:2416 kb ****************************************************************/ #include <cstdio> #include <cstring> #include <cstdlib> #include <cctype> #include <algorithm> #include <utility> #include <deque> #include <cmath> typedef long long ll; typedef ll T; struct Point { T x, y; Point(T qx = 0LL, T qy = 0LL) { x = qx; y = qy; } }; typedef Point Vector; Vector operator +(const Vector &a, const Vector &b) { return Vector(a.x + b.x, a.y + b.y); } Vector operator -(const Point &a, const Point &b) { return Vector(a.x - b.x, a.y - b.y); } Vector operator *(const Vector &a, T lam) { return Vector(a.x * lam, a.y * lam); } Vector operator *(T lam, const Vector &a) { return Vector(a.x * lam, a.y * lam); } inline T dot(const Vector &a, const Vector &b) { return (a.x * b.x + a.y * b.y); } inline T times(const Vector &a, const Vector &b) { return (a.x * b.y - a.y * b.x); } const int maxn = 50005; T C[maxn], S[maxn], P[maxn]; T f[maxn]; int n; ll c; void process() { for(int i = 1; i <= n; i ++) { S[i] = S[i - 1] + C[i]; P[i] = S[i] + (ll(i)); } } void dp() { std::deque<Point> Q; Q.push_back(Point(0LL, 0LL)); for(int i = 1; i <= n; i ++) { ll k = 2 * P[i]; Vector st(1, k); while(Q.size() > 1 && times(Q[1] - Q[0], st) > 0LL) { Q.pop_front(); } f[i] = c * c + P[i] * P[i] - 2LL * P[i] * c; f[i] += Q.front().y - k * Q.front().x; #ifdef LOCAL printf("f[%d] : %lld\n", i, f[i]); #endif Vector ins(P[i], f[i] + P[i] * P[i] + 2LL * P[i] * c); while(Q.size() > 1 && times(ins - Q.back(), Q.back() - Q[Q.size() - 2]) > 0LL) { #ifdef LOCAL printf("Deleting (%lld, %lld)...\n", Q.back().x, Q.back().y); #endif Q.pop_back(); } Q.push_back(ins); #ifdef LOCAL printf("Inserting (%lld, %lld)...\n", ins.x, ins.y); #endif } } int main() { scanf("%d%lld", &n, &c); c ++; for(int i = 1; i <= n; i ++) { scanf("%lld", &C[i]); } process(); dp(); printf("%lld\n", f[n]); return 0; }