[LibreOJ 2137][ZJOI2015]诸神眷顾的幻想乡
复习SAM了呢~
那个度数为1的点至多有20个的条件非常神奇……让我们想想怎么用。
我们发现,(钦定根后)在树上有一条路径是弯的这种情况非常不好统计,但如果是直着下来就很好说(把所有从根到一个点的路径扔到SAM里,然后是经典题)。但是,任何一条路一定会在某个度数为1的点为根的情况下是直的(可以意会一下吧(逃
然后我们从那20个点每个点当根DFS一遍,把搞出来的每一条从根开始的路径放前缀Trie里。然后对前缀Trie构一个SAM就行啦~
代码:
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <algorithm>
#include <utility>
#include <queue>
#include <vector>
const int BUFSIZE = 128 * 1024 * 1024;
char buf[BUFSIZE];
void *alloc(size_t size) {
static char *cur = buf;
if(cur - buf + size > BUFSIZE) {
return malloc(size);
} else {
char *ret = cur; cur += size;
return ret;
}
}
const int maxn = 100005;
std::vector<int> G[maxn];
int deg[maxn];
int ssiz;
struct TNode {
TNode *ch[10];
};
TNode *alloc_tn() {
auto ret = (TNode*)alloc(sizeof(TNode));
memset(ret -> ch, 0, sizeof(ret -> ch));
return ret;
}
TNode *step(TNode *o, int c) {
if(!(o -> ch[c])) {
o -> ch[c] = alloc_tn();
}
return (o -> ch[c]);
}
TNode *trt;
int col[maxn];
void dfs(int x, int fa, TNode *last) {
TNode *np = step(last, col[x]);
for(auto v : G[x]) {
if(v != fa) {
dfs(v, x, np);
}
}
}
struct Node {
int len; Node *fa;
Node *ch[10];
};
std::vector<Node*> pool;
Node *alloc_node(int len = 0, Node *fa = NULL) {
Node *ret = (Node*)alloc(sizeof(Node));
ret -> len = len; ret -> fa = fa;
memset(ret -> ch, 0, sizeof(ret -> ch));
pool.push_back(ret);
return ret;
}
Node *rt;
Node *extend(int c, Node *last) {
Node *np = alloc_node(last -> len + 1);
Node *p = last;
while(p != NULL && p -> ch[c] == NULL) {
p -> ch[c] = np; p = p -> fa;
}
if(p == NULL) {
np -> fa = rt;
} else {
Node *q = p -> ch[c];
if(q -> len == p -> len + 1) {
np -> fa = q;
} else {
Node *nq = alloc_node(p -> len + 1, q -> fa);
memcpy(nq -> ch, q -> ch, sizeof(q -> ch));
q -> fa = np -> fa = nq;
while(p != NULL && p -> ch[c] == q) {
p -> ch[c] = nq; p = p -> fa;
}
}
}
return np;
}
void dfs_2(TNode *o, Node *last) {
for(int i = 0; i < ssiz; i ++) {
TNode *v = o -> ch[i];
if(!v) continue;
Node *np = extend(i, last);
dfs_2(v, np);
}
}
using ll = long long;
int main() {
int n; scanf("%d%d", &n, &ssiz);
for(int i = 1; i <= n; i ++) {
scanf("%d", &col[i]);
}
trt = alloc_tn();
for(int i = 1; i <= n - 1; i ++) {
int u, v; scanf("%d%d", &u, &v);
G[u].push_back(v); G[v].push_back(u);
deg[u] ++; deg[v] ++;
}
for(int i = 1; i <= n; i ++) {
if(deg[i] == 1) {
dfs(i, 0, trt);
}
}
rt = alloc_node();
dfs_2(trt, rt);
ll ans = 0;
for(auto p : pool) {
if(p -> fa) {
ans += p -> len - p -> fa -> len;
}
}
printf("%lld\n", ans);
return 0;
}