[BZOJ 2654]tree
APIO的时候听了一下wqs二分,做了这题,然后现在才写题解……
wqs二分的思路大抵就是你要求必须选\(k\)个,那就二分每个操作的一个“额外代价”,然后进行没有限制的选择。然后最后选出来的个数事和你二分的代价正相关/反相关的。
这道题的话,就二分选择黑边的代价,然后跑一般的最小生成树(有相同边权时要选择黑边!)。当然我们会遇到一个问题,就是二分到\(x\)的时候选的比\(k\)多,到\(x + 1\)的时候又比\(k\)少了。这道题的处理方法,事考虑代价为\(x\)时,其实存在选\(k\)个的最优解(如果说大于\(x\)那就没了),因此我们钦点代价为\(x\)跑一遍限制只能选\(k\)条黑边的最短路即可。
代码:
/**************************************************************
Problem: 2654
User: danihao123
Language: C++
Result: Accepted
Time:5756 ms
Memory:5856 kb
****************************************************************/
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <algorithm>
#include <utility>
#include <vector>
const int maxn = 50005;
const int maxm = 100005;
struct Edge {
int u, v, d;
int typ;
bool operator <(const Edge &res) const {
if(d == res.d) {
return typ < res.typ;
} else {
return d < res.d;
}
}
};
Edge E[maxm];
int n, m, need;
int par[maxn], siz[maxn];
void init_dsu() {
for(int i = 1; i <= n; i ++) {
par[i] = i;
siz[i] = 1;
}
}
int get_fa(int x) {
if(par[x] == x) {
return x;
} else {
return (par[x] = get_fa(par[x]));
}
}
void link_set(int x, int y) {
if(siz[x] > siz[y]) std::swap(x, y);
siz[y] += siz[x];
par[x] = y;
}
void merge_set(int x, int y) {
return link_set(get_fa(x), get_fa(y));
}
bool is_same(int x, int y) {
return (get_fa(x) == get_fa(y));
}
typedef long long ll;
int ans = 0x7fffffff;
bool kruskal(int delta) {
std::vector<Edge> vec;
for(int i = 1; i <= m; i ++) {
Edge e = E[i];
if(e.typ == 0) {
e.d += delta;
}
vec.push_back(e);
}
std::sort(vec.begin(), vec.end());
int used = 0; ll tot = -(ll(delta)) * (ll(need));
init_dsu();
for(int i = 0; i < m; i ++) {
const Edge &e = vec[i];
int u = e.u, v = e.v;
if(!is_same(u, v)) {
if(used == need && e.typ == 0) continue;
merge_set(u, v); tot += e.d;
if(e.typ == 0) used ++;
}
}
if(used == need) {
ans = std::min(ans, (int(tot)));
return true;
} else {
return false;
}
}
int main() {
scanf("%d%d%d", &n, &m, &need);
for(int i = 1; i <= m; i ++) {
scanf("%d%d%d%d", &E[i].u, &E[i].v, &E[i].d, &E[i].typ);
E[i].u ++; E[i].v ++;
}
int L = -10000001, R = 10000001;
while(true) {
#ifdef LOCAL
printf("Range (%d, %d)\n", L, R);
fflush(stdout);
#endif
if(R - L <= 3) {
for(int i = R; i >= L; i --) {
if(kruskal(i)) {
break;
}
}
break;
}
int M = L + (R - L) / 2;
if(kruskal(M)) {
L = M;
} else {
R = M;
}
}
printf("%d\n", ans);
return 0;
}