[BZOJ 2876][NOI2012]骑行川藏

我终于A了……不就是拉格朗日乘数法的模板题吗

首先这道题最优情况下一定有(这里用\(E_i\)表示第\(i\)段路程的耗能):\(\sum_{i = 1}^n E_i = E_u\)。

然后这个东西是一个等式限制条件,然后我们还要最小化总用时,给人拉格朗日乘数法的即视感……

不管怎么说让我们来列式子吧:

\[h(x_1, x_2,\ldots ,x_n, \lambda) = \sum_{i = 1}^n \frac{s_i}{x_i} + \lambda \sum_{i = 1}^n k_i s_i (x_i - v_i)^2\]

\[\frac{\partial h}{\partial x_i} = -\frac{s_i}{x_i^2} + 2\lambda k_i s_i (x_i - v_i)\]

然后你会想这TM怎么解方程……

但是我们想一想,把\(\frac{\partial h}{\partial x_i} = 0\)稍作整理,得:

\[\frac{1}{2k_i\lambda} = x_i^2 (x_i - v_i)\]

对于式子的左边,是一个关于\(x_i\)的增函数(因为这道题默认了\(x_i\geq 0\)且\(x_i\geq v_i\)),然后不妨令\(\frac{1}{\lambda} = u\),可以发现\(u\)越大则\(x_i\)越大!这样一来那么我们的限制条件就会更加难以满足。

所以我们可以二分这个\(u\)来解这些方程。

BTW,这题卡精度非常厉害……

代码:

/**************************************************************
    Problem: 2876
    User: danihao123
    Language: C++
    Result: Accepted
    Time:3960 ms
    Memory:1524 kb
****************************************************************/
 
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <cassert>
typedef double R;
const R eps = 1e-12;
const int maxn = 10005;
int sign(R x) {
  if(fabs(x) < eps) {
    return 0;
  } else {
    if(x < 0) {
      return -1;
    } else {
      return 1;
    }
  }
}
R s[maxn], k[maxn], V[maxn];
R rf(int i, R v, R delta) {
  return v * v * (v - V[i]) - delta;
}
R rf2(int i, R v) {
  return 3 * v * v - 2 * v * V[i];
}
R gen_rt(int i, R delta) {
  R x = 1e6;
  int lambda = 100;
  while(lambda --) {
    R a = rf(i, x, delta);
    if(sign(a) == 0) break;
    R da = rf2(i, x);
    x -= a / da;
  }
  return x;
}
int n; R E;
R gen_lft() {
  R l = 1e-14, r = 1e16;
  while(r - l > eps) {
#ifdef LOCAL
    printf("State [%.18lf, %.18lf]\n", l, r);
#endif
    R M = (l + r) / 2;
    R T = 0;
    for(int i = 1; i <= n; i ++) {
      R v = gen_rt(i, M / (2 * k[i]));
      T += (v - V[i]) * (v - V[i]) * k[i] * s[i];
    }
    if(sign(T - E) <= 0) {
      l = M;
    } else {
      r = M;
    }
  }
  return l;
}
 
int main() {
  std::cin >> n >> E;
  for(int i = 1; i <= n; i ++) {
    std::cin >> s[i] >> k[i] >> V[i];
  }
  R M = gen_lft();
  R tm = 0;
  for(int i = 1; i <= n; i ++) {
    R v = gen_rt(i, M / (2 * k[i]));
    tm += s[i] / v;
  }
  printf("%.8lf\n", tm);
  return 0;
}