[LibreOJ 6433][PKUSC2018]最大前缀和
转载请注明出处:http://danihao123.is-programmer.com/
给你一个长为\(n\)的整数序列\(a\),求出将序列随机打乱之后的最大前缀和(不能选空前缀!)的期望,答案乘上\(n!\)之后对\(998244353\)输出。
\(1\leq n\leq 20, \sum_{i = 1}^n |a_i|\leq 10^9\)。
PKUSC……棋差一招吧……其实不就是我太弱了吗
真的是NOIP题吧……
考虑对于一个\(a\)的排列\(A\),假设前缀\(1\ldots i\)是最大前缀(这里加设多解的情况取\(i\)最小的),那么很显然\(A[1\ldots i]\)的每一个后缀(除了其本身)都是大于\(0\)的(下称性质1),而\(A[i + 1\ldots n]\)的所有前缀都是小于等于\(0\)的(下称性质2)。
因此设计状态\(f_s\)表示用集合\(s\)可以凑出来的满足性质1的排列数量,\(g_s\)表示满足性质2的数量,这两个东西都很好转移。然后我们枚举一下最大前缀和的集合,然后利用上述状态计算一下就行了。
#include <cstdio> #include <cstring> #include <cstdlib> #include <algorithm> #include <utility> using ll = long long; const int ha = 998244353; const int maxn = 20; int ma[(1 << maxn)], A[maxn], n; void process_ma() { for(int i = 0; i < maxn; i ++) { ma[(1 << i)] = i; } } int lowbit(int x) { return x & (-x); } const int INF = 0x7f7f7f7f; int sum[1 << maxn]; int calc_sum(int S) { if(S == 0) return 0; if(sum[S] != INF) return sum[S]; int low = lowbit(S); sum[S] = calc_sum(S ^ low) + A[ma[low]]; return sum[S]; } int f[1 << maxn]; // > 0 int calc_f(int S) { if(S == 0) return 1; int lb = lowbit(S); // if(S == lb && calc_sum(S) <= 0) return 0; if(f[S] != INF) return f[S]; f[S] = 0; for(int i = 0; i < n; i ++) { if(((1 << i) & S) != 0 && (S == lb || calc_sum(S ^ (1 << i)) > 0)) { f[S] = (f[S] + calc_f(S ^ (1 << i))) % ha; } } return f[S]; } int g[1 << maxn]; // <= 0 int calc_g(int S) { if(S == 0) return 1; if(calc_sum(S) > 0) return 0; if(g[S] != INF) return g[S]; g[S] = 0; for(int i = 0; i < n; i ++) { if((1 << i) & S) { g[S] = (g[S] + calc_g(S ^ (1 << i))) % ha; } } return g[S]; } int main() { process_ma(); memset(sum, 0x7f, sizeof(sum)); memset(f, 0x7f, sizeof(f)); memset(g, 0x7f, sizeof(g)); scanf("%d", &n); for(int i = 0; i < n; i ++) { scanf("%d", &A[i]); } int ans = 0; int al = (1 << n) - 1; for(int s = 0; s < (1 << n); s ++) { ll fv = calc_f(s), gv = calc_g(al ^ s); #ifdef LOCAL printf("f[%d] : %lld\n", s, fv); #endif fv = (fv * gv) % (ll(ha)); ll S = calc_sum(s); while(S < 0) S += ha; S = S % (ll(ha)); fv = (fv * S) % (ll(ha)); ans = (ans + (int(fv))) % ha; #ifdef LOCAL printf("delta[%d] : %lld\n", s, fv); #endif } printf("%d\n", ans); return 0; }