[BZOJ 2654]tree

APIO的时候听了一下wqs二分,做了这题,然后现在才写题解……

wqs二分的思路大抵就是你要求必须选\(k\)个,那就二分每个操作的一个“额外代价”,然后进行没有限制的选择。然后最后选出来的个数事和你二分的代价正相关/反相关的。

这道题的话,就二分选择黑边的代价,然后跑一般的最小生成树(有相同边权时要选择黑边!)。当然我们会遇到一个问题,就是二分到\(x\)的时候选的比\(k\)多,到\(x + 1\)的时候又比\(k\)少了。这道题的处理方法,事考虑代价为\(x\)时,其实存在选\(k\)个的最优解(如果说大于\(x\)那就没了),因此我们钦点代价为\(x\)跑一遍限制只能选\(k\)条黑边的最短路即可。

代码:

/**************************************************************
    Problem: 2654
    User: danihao123
    Language: C++
    Result: Accepted
    Time:5756 ms
    Memory:5856 kb
****************************************************************/
 
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <algorithm>
#include <utility>
#include <vector>
const int maxn = 50005;
const int maxm = 100005;
struct Edge {
  int u, v, d;
  int typ;
  bool operator <(const Edge &res) const {
    if(d == res.d) {
      return typ < res.typ;
    } else {
      return d < res.d;
    }
  }
};
 
Edge E[maxm];
int n, m, need;
 
int par[maxn], siz[maxn];
void init_dsu() {
  for(int i = 1; i <= n; i ++) {
    par[i] = i;
    siz[i] = 1;
  }
}
int get_fa(int x) {
  if(par[x] == x) {
    return x;
  } else {
    return (par[x] = get_fa(par[x]));
  }
}
void link_set(int x, int y) {
  if(siz[x] > siz[y]) std::swap(x, y);
  siz[y] += siz[x];
  par[x] = y;
}
void merge_set(int x, int y) {
  return link_set(get_fa(x), get_fa(y));
}
bool is_same(int x, int y) {
  return (get_fa(x) == get_fa(y));
}
 
typedef long long ll;
int ans = 0x7fffffff;
bool kruskal(int delta) {
  std::vector<Edge> vec;
  for(int i = 1; i <= m; i ++) {
    Edge e = E[i];
    if(e.typ == 0) {
      e.d += delta;
    }
    vec.push_back(e);
  }
  std::sort(vec.begin(), vec.end());
  int used = 0; ll tot = -(ll(delta)) * (ll(need));
  init_dsu();
  for(int i = 0; i < m; i ++) {
    const Edge &e = vec[i];
    int u = e.u, v = e.v;
    if(!is_same(u, v)) {
      if(used == need && e.typ == 0) continue;
      merge_set(u, v); tot += e.d;
      if(e.typ == 0) used ++;
    }
  }
  if(used == need) {
    ans = std::min(ans, (int(tot)));
    return true;
  } else {
    return false;
  }
}
 
int main() {
  scanf("%d%d%d", &n, &m, &need);
  for(int i = 1; i <= m; i ++) {
    scanf("%d%d%d%d", &E[i].u, &E[i].v, &E[i].d, &E[i].typ);
    E[i].u ++; E[i].v ++;
  }
  int L = -10000001, R = 10000001;
  while(true) {
#ifdef LOCAL
    printf("Range (%d, %d)\n", L, R);
    fflush(stdout);
#endif
    if(R - L <= 3) {
      for(int i = R; i >= L; i --) {
        if(kruskal(i)) {
          break;
        }
      }
      break;
    }
    int M = L + (R - L) / 2;
    if(kruskal(M)) {
      L = M;
    } else {
      R = M;
    }
  }
  printf("%d\n", ans);
  return 0;
}