[LibreOJ 2137][ZJOI2015]诸神眷顾的幻想乡
复习SAM了呢~
那个度数为1的点至多有20个的条件非常神奇……让我们想想怎么用。
我们发现,(钦定根后)在树上有一条路径是弯的这种情况非常不好统计,但如果是直着下来就很好说(把所有从根到一个点的路径扔到SAM里,然后是经典题)。但是,任何一条路一定会在某个度数为1的点为根的情况下是直的(可以意会一下吧(逃
然后我们从那20个点每个点当根DFS一遍,把搞出来的每一条从根开始的路径放前缀Trie里。然后对前缀Trie构一个SAM就行啦~
代码:
#include <cstdio> #include <cstring> #include <cstdlib> #include <cctype> #include <algorithm> #include <utility> #include <queue> #include <vector> const int BUFSIZE = 128 * 1024 * 1024; char buf[BUFSIZE]; void *alloc(size_t size) { static char *cur = buf; if(cur - buf + size > BUFSIZE) { return malloc(size); } else { char *ret = cur; cur += size; return ret; } } const int maxn = 100005; std::vector<int> G[maxn]; int deg[maxn]; int ssiz; struct TNode { TNode *ch[10]; }; TNode *alloc_tn() { auto ret = (TNode*)alloc(sizeof(TNode)); memset(ret -> ch, 0, sizeof(ret -> ch)); return ret; } TNode *step(TNode *o, int c) { if(!(o -> ch[c])) { o -> ch[c] = alloc_tn(); } return (o -> ch[c]); } TNode *trt; int col[maxn]; void dfs(int x, int fa, TNode *last) { TNode *np = step(last, col[x]); for(auto v : G[x]) { if(v != fa) { dfs(v, x, np); } } } struct Node { int len; Node *fa; Node *ch[10]; }; std::vector<Node*> pool; Node *alloc_node(int len = 0, Node *fa = NULL) { Node *ret = (Node*)alloc(sizeof(Node)); ret -> len = len; ret -> fa = fa; memset(ret -> ch, 0, sizeof(ret -> ch)); pool.push_back(ret); return ret; } Node *rt; Node *extend(int c, Node *last) { Node *np = alloc_node(last -> len + 1); Node *p = last; while(p != NULL && p -> ch[c] == NULL) { p -> ch[c] = np; p = p -> fa; } if(p == NULL) { np -> fa = rt; } else { Node *q = p -> ch[c]; if(q -> len == p -> len + 1) { np -> fa = q; } else { Node *nq = alloc_node(p -> len + 1, q -> fa); memcpy(nq -> ch, q -> ch, sizeof(q -> ch)); q -> fa = np -> fa = nq; while(p != NULL && p -> ch[c] == q) { p -> ch[c] = nq; p = p -> fa; } } } return np; } void dfs_2(TNode *o, Node *last) { for(int i = 0; i < ssiz; i ++) { TNode *v = o -> ch[i]; if(!v) continue; Node *np = extend(i, last); dfs_2(v, np); } } using ll = long long; int main() { int n; scanf("%d%d", &n, &ssiz); for(int i = 1; i <= n; i ++) { scanf("%d", &col[i]); } trt = alloc_tn(); for(int i = 1; i <= n - 1; i ++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); deg[u] ++; deg[v] ++; } for(int i = 1; i <= n; i ++) { if(deg[i] == 1) { dfs(i, 0, trt); } } rt = alloc_node(); dfs_2(trt, rt); ll ans = 0; for(auto p : pool) { if(p -> fa) { ans += p -> len - p -> fa -> len; } } printf("%lld\n", ans); return 0; }