[LibreOJ 2137][ZJOI2015]诸神眷顾的幻想乡

复习SAM了呢~

那个度数为1的点至多有20个的条件非常神奇……让我们想想怎么用。

我们发现,(钦定根后)在树上有一条路径是弯的这种情况非常不好统计,但如果是直着下来就很好说(把所有从根到一个点的路径扔到SAM里,然后是经典题)。但是,任何一条路一定会在某个度数为1的点为根的情况下是直的(可以意会一下吧(逃

然后我们从那20个点每个点当根DFS一遍,把搞出来的每一条从根开始的路径放前缀Trie里。然后对前缀Trie构一个SAM就行啦~

代码:

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <algorithm>
#include <utility>
#include <queue>
#include <vector>
const int BUFSIZE = 128 * 1024 * 1024;
char buf[BUFSIZE];
void *alloc(size_t size) {
  static char *cur = buf;
  if(cur - buf + size > BUFSIZE) {
    return malloc(size);
  } else {
    char *ret = cur; cur += size;
    return ret;
  }
}

const int maxn = 100005;
std::vector<int> G[maxn];
int deg[maxn];

int ssiz;
struct TNode {
  TNode *ch[10];
};
TNode *alloc_tn() {
  auto ret = (TNode*)alloc(sizeof(TNode));
  memset(ret -> ch, 0, sizeof(ret -> ch));
  return ret;
}
TNode *step(TNode *o, int c) {
  if(!(o -> ch[c])) {
    o -> ch[c] = alloc_tn();
  }
  return (o -> ch[c]);
}
TNode *trt;
int col[maxn];
void dfs(int x, int fa, TNode *last) {
  TNode *np = step(last, col[x]);
  for(auto v : G[x]) {
    if(v != fa) {
      dfs(v, x, np);
    }
  }
}

struct Node {
  int len; Node *fa;
  Node *ch[10];
};
std::vector<Node*> pool;
Node *alloc_node(int len = 0, Node *fa = NULL) {
  Node *ret = (Node*)alloc(sizeof(Node));
  ret -> len = len; ret -> fa = fa;
  memset(ret -> ch, 0, sizeof(ret -> ch));
  pool.push_back(ret);
  return ret;
}
Node *rt;
Node *extend(int c, Node *last) {
  Node *np = alloc_node(last -> len + 1);
  Node *p = last;
  while(p != NULL && p -> ch[c] == NULL) {
    p -> ch[c] = np; p = p -> fa;
  }
  if(p == NULL) {
    np -> fa = rt;
  } else {
    Node *q = p -> ch[c];
    if(q -> len == p -> len + 1) {
      np -> fa = q;
    } else {
      Node *nq = alloc_node(p -> len + 1, q -> fa);
      memcpy(nq -> ch, q -> ch, sizeof(q -> ch));
      q -> fa = np -> fa = nq;
      while(p != NULL && p -> ch[c] == q) {
        p -> ch[c] = nq; p = p -> fa;
      }
    }
  }
  return np;
}
void dfs_2(TNode *o, Node *last) {
  for(int i = 0; i < ssiz; i ++) {
    TNode *v = o -> ch[i];
    if(!v) continue;
    Node *np = extend(i, last);
    dfs_2(v, np);
  }
}

using ll = long long;
int main() {
  int n; scanf("%d%d", &n, &ssiz);
  for(int i = 1; i <= n; i ++) {
    scanf("%d", &col[i]);
  }
  trt = alloc_tn();
  for(int i = 1; i <= n - 1; i ++) {
    int u, v; scanf("%d%d", &u, &v);
    G[u].push_back(v); G[v].push_back(u);
    deg[u] ++; deg[v] ++;
  }
  for(int i = 1; i <= n; i ++) {
    if(deg[i] == 1) {
      dfs(i, 0, trt);
    }
  }
  rt = alloc_node();
  dfs_2(trt, rt);
  ll ans = 0;
  for(auto p : pool) {
    if(p -> fa) {
      ans += p -> len - p -> fa -> len;
    }
  }
  printf("%lld\n", ans);
  return 0;
}