[BZOJ 2654]tree
APIO的时候听了一下wqs二分,做了这题,然后现在才写题解……
wqs二分的思路大抵就是你要求必须选\(k\)个,那就二分每个操作的一个“额外代价”,然后进行没有限制的选择。然后最后选出来的个数事和你二分的代价正相关/反相关的。
这道题的话,就二分选择黑边的代价,然后跑一般的最小生成树(有相同边权时要选择黑边!)。当然我们会遇到一个问题,就是二分到\(x\)的时候选的比\(k\)多,到\(x + 1\)的时候又比\(k\)少了。这道题的处理方法,事考虑代价为\(x\)时,其实存在选\(k\)个的最优解(如果说大于\(x\)那就没了),因此我们钦点代价为\(x\)跑一遍限制只能选\(k\)条黑边的最短路即可。
代码:
/************************************************************** Problem: 2654 User: danihao123 Language: C++ Result: Accepted Time:5756 ms Memory:5856 kb ****************************************************************/ #include <cstdio> #include <cstring> #include <cstdlib> #include <cctype> #include <algorithm> #include <utility> #include <vector> const int maxn = 50005; const int maxm = 100005; struct Edge { int u, v, d; int typ; bool operator <(const Edge &res) const { if(d == res.d) { return typ < res.typ; } else { return d < res.d; } } }; Edge E[maxm]; int n, m, need; int par[maxn], siz[maxn]; void init_dsu() { for(int i = 1; i <= n; i ++) { par[i] = i; siz[i] = 1; } } int get_fa(int x) { if(par[x] == x) { return x; } else { return (par[x] = get_fa(par[x])); } } void link_set(int x, int y) { if(siz[x] > siz[y]) std::swap(x, y); siz[y] += siz[x]; par[x] = y; } void merge_set(int x, int y) { return link_set(get_fa(x), get_fa(y)); } bool is_same(int x, int y) { return (get_fa(x) == get_fa(y)); } typedef long long ll; int ans = 0x7fffffff; bool kruskal(int delta) { std::vector<Edge> vec; for(int i = 1; i <= m; i ++) { Edge e = E[i]; if(e.typ == 0) { e.d += delta; } vec.push_back(e); } std::sort(vec.begin(), vec.end()); int used = 0; ll tot = -(ll(delta)) * (ll(need)); init_dsu(); for(int i = 0; i < m; i ++) { const Edge &e = vec[i]; int u = e.u, v = e.v; if(!is_same(u, v)) { if(used == need && e.typ == 0) continue; merge_set(u, v); tot += e.d; if(e.typ == 0) used ++; } } if(used == need) { ans = std::min(ans, (int(tot))); return true; } else { return false; } } int main() { scanf("%d%d%d", &n, &m, &need); for(int i = 1; i <= m; i ++) { scanf("%d%d%d%d", &E[i].u, &E[i].v, &E[i].d, &E[i].typ); E[i].u ++; E[i].v ++; } int L = -10000001, R = 10000001; while(true) { #ifdef LOCAL printf("Range (%d, %d)\n", L, R); fflush(stdout); #endif if(R - L <= 3) { for(int i = R; i >= L; i --) { if(kruskal(i)) { break; } } break; } int M = L + (R - L) / 2; if(kruskal(M)) { L = M; } else { R = M; } } printf("%d\n", ans); return 0; }