[LibreOJ 6433][PKUSC2018]最大前缀和

danihao123 posted @ 2018年9月03日 16:12 in 题解 with tags loj PKUSC 状压dp , 1251 阅读
转载请注明出处:http://danihao123.is-programmer.com/

给你一个长为\(n\)的整数序列\(a\),求出将序列随机打乱之后的最大前缀和(不能选空前缀!)的期望,答案乘上\(n!\)之后对\(998244353\)输出。

\(1\leq n\leq 20, \sum_{i = 1}^n |a_i|\leq 10^9\)。

PKUSC……棋差一招吧……其实不就是我太弱了吗

真的是NOIP题吧……

考虑对于一个\(a\)的排列\(A\),假设前缀\(1\ldots i\)是最大前缀(这里加设多解的情况取\(i\)最小的),那么很显然\(A[1\ldots i]\)的每一个后缀(除了其本身)都是大于\(0\)的(下称性质1),而\(A[i + 1\ldots n]\)的所有前缀都是小于等于\(0\)的(下称性质2)。

因此设计状态\(f_s\)表示用集合\(s\)可以凑出来的满足性质1的排列数量,\(g_s\)表示满足性质2的数量,这两个东西都很好转移。然后我们枚举一下最大前缀和的集合,然后利用上述状态计算一下就行了。

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <utility>
using ll = long long;
const int ha = 998244353;
const int maxn = 20;
int ma[(1 << maxn)], A[maxn], n;
void process_ma() {
  for(int i = 0; i < maxn; i ++) {
    ma[(1 << i)] = i;
  }
}
int lowbit(int x) {
  return x & (-x);
}

const int INF = 0x7f7f7f7f;
int sum[1 << maxn];
int calc_sum(int S) {
  if(S == 0) return 0;
  if(sum[S] != INF) return sum[S];
  int low = lowbit(S);
  sum[S] = calc_sum(S ^ low) + A[ma[low]];
  return sum[S];
}
int f[1 << maxn]; // > 0
int calc_f(int S) {
  if(S == 0) return 1;
  int lb = lowbit(S);
  // if(S == lb && calc_sum(S) <= 0) return 0;
  if(f[S] != INF) return f[S];
  f[S] = 0;
  for(int i = 0; i < n; i ++) {
    if(((1 << i) & S) != 0 && (S == lb || calc_sum(S ^ (1 << i)) > 0)) {
      f[S] = (f[S] + calc_f(S ^ (1 << i))) % ha;
    }
  }
  return f[S];
}
int g[1 << maxn]; // <= 0
int calc_g(int S) {
  if(S == 0) return 1;
  if(calc_sum(S) > 0) return 0;
  if(g[S] != INF) return g[S];
  g[S] = 0;
  for(int i = 0; i < n; i ++) {
    if((1 << i) & S) {
      g[S] = (g[S] + calc_g(S ^ (1 << i))) % ha;
    }
  }
  return g[S];
}

int main() {
  process_ma();
  memset(sum, 0x7f, sizeof(sum));
  memset(f, 0x7f, sizeof(f));
  memset(g, 0x7f, sizeof(g));
  scanf("%d", &n);
  for(int i = 0; i < n; i ++) {
    scanf("%d", &A[i]);
  }
  int ans = 0;
  int al = (1 << n) - 1;
  for(int s = 0; s < (1 << n); s ++) {
    ll fv = calc_f(s), gv = calc_g(al ^ s);
#ifdef LOCAL
    printf("f[%d] : %lld\n", s, fv);
#endif
    fv = (fv * gv) % (ll(ha));
    ll S = calc_sum(s); while(S < 0) S += ha;
    S = S % (ll(ha));
    fv = (fv * S) % (ll(ha));
    ans = (ans + (int(fv))) % ha;
#ifdef LOCAL
    printf("delta[%d] : %lld\n", s, fv);
#endif
  }
  printf("%d\n", ans);
  return 0;
}

登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter