UTPC 2011 L L番目の数字(AOJ 2270)

問題

日本語なので本文参照(http://www.utpc.jp/2011/


N頂点からなる木が与えられる。
i番目の頂点には数字x[i]が書かれている。
この木に対して次のようなQ個のクエリに答えよ。
クエリ:頂点v[i]からw[i]へのパス上に書かれている数字のうち、l[i]番目に小さいものを求める。

制約条件

N≦10^5
Q≦10^5
x[i]≦10^9

方針

解説スライドの通り。
全ての頂点に対して、根からのパスに含まれる数字を要素にもつ二分探索木を作る。
これを永続にすることで、全部合わせてメモリ領域がO(nlogn)で作れる。


ちなみに、永続化は、二分探索木において、
根から葉までのパスの部分までしか前回と変わらないので、
そこだけノードを新規に作るようにするだけ。


根までの二分探索木が出来ると、
x以下の要素の個数がO(logn)で求まる。
すると、LCAを使えばパス上のx以下の要素の個数も同じオーダーで求まる。
xについて二分探索して、個数がl以上になる最小のxが答え。


永続二分探索木よりも、永続segmenttreeを書いたほうがだいぶ実装が楽。
自分はspaghetti sourceの二分探索木を、永続化した。
(最初に全ての要素を0個にした二分探索木を作っておけば、
回転を永続化しなくてすむので、結局そんなに実装量は多くないのだけど)

ソースコード

template <class T>
struct avl_tree {
  struct node {
    T key;
    int size, height, sum;
    node *child[2];
    node(const T &key) : key(key), size(1), height(1), sum(key.second) {
      child[0] = child[1] = 0; }
  } *root;
  typedef node *pointer;
  avl_tree() { root = NULL; }

  pointer find(const T &key) { return find(root, key); }
  node *find(node *t, const T &key) {
    if (t == NULL) return NULL;
    if (key.first == t->key.first) return t;
    else if (key < t->key) return find(t->child[0], key);
    else                   return find(t->child[1], key);
  }
  void insert(const T &key) { root = insert(root, new node(key)); }
  node *insert(node *t, node *x) {
    if (t == NULL) return x;
  	if(x->key.first == t->key.first){
  		t->key.second = x->key.second;
  		return t;
  	}
    if (x->key <= t->key) t->child[0] = insert(t->child[0], x);
    else                  t->child[1] = insert(t->child[1], x);
    t->size += 1;
    return balance(t);
  }
  node *move_down(node *t, node *rhs) {
    if (t == NULL) return rhs;
    t->child[1] = move_down(t->child[1], rhs);
    return balance(t);
  }
#define sz(t) (t ? t->size : 0)
#define ht(t) (t ? t->height : 0)
#define sm(t) (t ? t->sum : 0)
	node *update(node *t){
		if(!t) return t;
		t->height = max(ht(t->child[0]), ht(t->child[1])) + 1;
		t->size = sz(t->child[0]) + sz(t->child[1]) + 1;
		t->sum = sm(t->child[0]) + sm(t->child[1]) + t->key.second;
		return t;
	}
  node *rotate(node *t, int l, int r) {
    node *s = t->child[r];
    t->child[r] = s->child[l];
    s->child[l] = balance(t);
  	update(t);
    return balance(update(s));
  }
  node *balance(node *t) {
    for (int i = 0; i < 2; ++i) {
      if (ht(t->child[!i]) - ht(t->child[i]) < -1) {
        if (ht(t->child[i]->child[!i]) - ht(t->child[i]->child[i]) > 0)
          t->child[i] = rotate(t->child[i], i, !i);
      	return rotate(t, !i, i);
      }
    }
  	return update(t);
  }
	//t以下の部分木のうち、キーがbound以下である要素の個数を求める。
	int sum(int bound, node* t){
		if(!t) return 0;
		int res = 0;
		if(t->key.first <= bound){
			res += t->key.second;
			res += sm(t->child[0]);
			res += sum(bound, t->child[1]);
		}
		else{
			res += sum(bound, t->child[0]);
		}
		return res;
	}
	//挿入の永続版。二分探索木の変更のある部分だけnewする。
	node *insertp(node *t, T key) {
		if (t == NULL) return new node(key);
		
		node *res = new node(t->key);
		res->size = t->size;
		res->height = t->height;
		res->child[0] = t->child[0];
		res->child[1] = t->child[1];
		
		if(key.first == t->key.first){
			res->key.second = key.second;
			return update(res);
		}
		if (key <= t->key) res->child[0] = insertp(t->child[0], key);
		else               res->child[1] = insertp(t->child[1], key);
		return update(res);
	}
	void insertp(T key){ root = insertp(root, key); }
	int sum(int bound){ return sum(bound, root); }
};

const int MX = 100010;
const int MX_L = 17;
int n, q;
int x[MX];
int parent[MX_L][MX], depth[MX];
vector<vi> e;
avl_tree<pi> ts[MX];

void rec(int c, int p, avl_tree<pi> &t, int d){
	avl_tree<pi>::pointer pt = t.find(mp(x[c], 0));
	ts[c].root = t.root;
	ts[c].insertp(mp(x[c], pt->key.second + 1));
	depth[c] = d;
	parent[0][c] = p;
	rep(i, e[c].size()) if(e[c][i] != p)
	rec(e[c][i], c, ts[c], d + 1);
}
int lca(int a, int b){
	if(depth[a] > depth[b]) swap(a, b);
	for(int i = MX_L - 1; i >= 0; i--) if(depth[parent[i][b]] >= depth[a])
		b = parent[i][b];
	if(a == b) return a;
	for(int i = MX_L - 1; i >= 0; i--) if(parent[i][a] != parent[i][b]){
		a = parent[i][a];
		b = parent[i][b];
	}
	return parent[0][a];
}

int main()
{
	scanf("%d%d", &n, &q);
	rep(i, n) scanf("%d", x + i);
	e.resize(n);
	rep(i, n - 1){
		int a, b;
		scanf("%d%d", &a, &b);
		a--; b--;
		e[a].pb(b);
		e[b].pb(a);
	}
	rep(i, n) ts[n].insert(mp(x[i], 0));
	rec(0, n, ts[n], 1);
	
	parent[0][n] = n;
	rep(i, MX_L - 1) rep(j, n) parent[i + 1][j] = parent[i][parent[i][j]];
	while(q--){
		int v, w, l;
		int lo = 0, hi = inf, mid;
		scanf("%d%d%d", &v, &w, &l);
		v--; w--;
		int a = lca(v, w), b = parent[0][a];
		while(lo + 1 < hi){
			mid = (lo + hi) / 2;
			
			int cnt = ts[w].sum(mid) + ts[v].sum(mid) - ts[a].sum(mid) - ts[b].sum(mid);
			if(cnt >= l) hi = mid;
			else lo = mid;
		}
		printf("%d\n", hi);
	}
	return 0;
}