UTPC 2011 L L番目の数字(AOJ 2270)
問題
日本語なので本文参照(http://www.utpc.jp/2011/)
N頂点からなる木が与えられる。
i番目の頂点には数字x[i]が書かれている。
この木に対して次のようなQ個のクエリに答えよ。
クエリ:頂点v[i]からw[i]へのパス上に書かれている数字のうち、l[i]番目に小さいものを求める。
制約条件
N≦10^5
Q≦10^5
x[i]≦10^9
方針
解説スライドの通り。
全ての頂点に対して、根からのパスに含まれる数字を要素にもつ二分探索木を作る。
これを永続にすることで、全部合わせてメモリ領域がO(nlogn)で作れる。
ちなみに、永続化は、二分探索木において、
根から葉までのパスの部分までしか前回と変わらないので、
そこだけノードを新規に作るようにするだけ。
根までの二分探索木が出来ると、
x以下の要素の個数がO(logn)で求まる。
すると、LCAを使えばパス上のx以下の要素の個数も同じオーダーで求まる。
xについて二分探索して、個数がl以上になる最小のxが答え。
永続二分探索木よりも、永続segmenttreeを書いたほうがだいぶ実装が楽。
自分はspaghetti sourceの二分探索木を、永続化した。
(最初に全ての要素を0個にした二分探索木を作っておけば、
回転を永続化しなくてすむので、結局そんなに実装量は多くないのだけど)
ソースコード
template <class T> struct avl_tree { struct node { T key; int size, height, sum; node *child[2]; node(const T &key) : key(key), size(1), height(1), sum(key.second) { child[0] = child[1] = 0; } } *root; typedef node *pointer; avl_tree() { root = NULL; } pointer find(const T &key) { return find(root, key); } node *find(node *t, const T &key) { if (t == NULL) return NULL; if (key.first == t->key.first) return t; else if (key < t->key) return find(t->child[0], key); else return find(t->child[1], key); } void insert(const T &key) { root = insert(root, new node(key)); } node *insert(node *t, node *x) { if (t == NULL) return x; if(x->key.first == t->key.first){ t->key.second = x->key.second; return t; } if (x->key <= t->key) t->child[0] = insert(t->child[0], x); else t->child[1] = insert(t->child[1], x); t->size += 1; return balance(t); } node *move_down(node *t, node *rhs) { if (t == NULL) return rhs; t->child[1] = move_down(t->child[1], rhs); return balance(t); } #define sz(t) (t ? t->size : 0) #define ht(t) (t ? t->height : 0) #define sm(t) (t ? t->sum : 0) node *update(node *t){ if(!t) return t; t->height = max(ht(t->child[0]), ht(t->child[1])) + 1; t->size = sz(t->child[0]) + sz(t->child[1]) + 1; t->sum = sm(t->child[0]) + sm(t->child[1]) + t->key.second; return t; } node *rotate(node *t, int l, int r) { node *s = t->child[r]; t->child[r] = s->child[l]; s->child[l] = balance(t); update(t); return balance(update(s)); } node *balance(node *t) { for (int i = 0; i < 2; ++i) { if (ht(t->child[!i]) - ht(t->child[i]) < -1) { if (ht(t->child[i]->child[!i]) - ht(t->child[i]->child[i]) > 0) t->child[i] = rotate(t->child[i], i, !i); return rotate(t, !i, i); } } return update(t); } //t以下の部分木のうち、キーがbound以下である要素の個数を求める。 int sum(int bound, node* t){ if(!t) return 0; int res = 0; if(t->key.first <= bound){ res += t->key.second; res += sm(t->child[0]); res += sum(bound, t->child[1]); } else{ res += sum(bound, t->child[0]); } return res; } //挿入の永続版。二分探索木の変更のある部分だけnewする。 node *insertp(node *t, T key) { if (t == NULL) return new node(key); node *res = new node(t->key); res->size = t->size; res->height = t->height; res->child[0] = t->child[0]; res->child[1] = t->child[1]; if(key.first == t->key.first){ res->key.second = key.second; return update(res); } if (key <= t->key) res->child[0] = insertp(t->child[0], key); else res->child[1] = insertp(t->child[1], key); return update(res); } void insertp(T key){ root = insertp(root, key); } int sum(int bound){ return sum(bound, root); } }; const int MX = 100010; const int MX_L = 17; int n, q; int x[MX]; int parent[MX_L][MX], depth[MX]; vector<vi> e; avl_tree<pi> ts[MX]; void rec(int c, int p, avl_tree<pi> &t, int d){ avl_tree<pi>::pointer pt = t.find(mp(x[c], 0)); ts[c].root = t.root; ts[c].insertp(mp(x[c], pt->key.second + 1)); depth[c] = d; parent[0][c] = p; rep(i, e[c].size()) if(e[c][i] != p) rec(e[c][i], c, ts[c], d + 1); } int lca(int a, int b){ if(depth[a] > depth[b]) swap(a, b); for(int i = MX_L - 1; i >= 0; i--) if(depth[parent[i][b]] >= depth[a]) b = parent[i][b]; if(a == b) return a; for(int i = MX_L - 1; i >= 0; i--) if(parent[i][a] != parent[i][b]){ a = parent[i][a]; b = parent[i][b]; } return parent[0][a]; } int main() { scanf("%d%d", &n, &q); rep(i, n) scanf("%d", x + i); e.resize(n); rep(i, n - 1){ int a, b; scanf("%d%d", &a, &b); a--; b--; e[a].pb(b); e[b].pb(a); } rep(i, n) ts[n].insert(mp(x[i], 0)); rec(0, n, ts[n], 1); parent[0][n] = n; rep(i, MX_L - 1) rep(j, n) parent[i + 1][j] = parent[i][parent[i][j]]; while(q--){ int v, w, l; int lo = 0, hi = inf, mid; scanf("%d%d%d", &v, &w, &l); v--; w--; int a = lca(v, w), b = parent[0][a]; while(lo + 1 < hi){ mid = (lo + hi) / 2; int cnt = ts[w].sum(mid) + ts[v].sum(mid) - ts[a].sum(mid) - ts[b].sum(mid); if(cnt >= l) hi = mid; else lo = mid; } printf("%d\n", hi); } return 0; }