Codeforces 392(#230 Div1) D. Three Arrays

問題

n項からなる数列A, B, Cが与えられる。
Aの先頭u項、Bの先頭v項、Cの先頭w項の和集合と、
A, B, Cの和集合が一致するような(u, v, w)について、u + v + wの最小値を求めよ。

制約条件

n≦10^5
各項≦10^9

方針

uを固定すると、(v, wの最適値)は、vが増加したときwが単調に減少する列になる。
この、wが下がった点だけを覚えておくと、列の折れ線が再現できる。


この列を記憶できているときにuを1減らすことを考える。
A[u]がB, Cに出現する位置をそれぞれib, icとする。


すると、(v, w)の点うち、修正しなければいけないのは、
ib<vかつ、w<icであるようなもの。つまり(ib, ic)の厳密に左下にある点。


左下になにも点がなければ何もおこらない。


あるときは、このような点を全て削除すると同時に、

  • 折れ線とy = icが交わっていた点に新たな点を追加、
  • 折れ線とx = ibが交わっていた点に新たな点を追加、

するようにすれば、2点の追加だけで列の更新ができる。
(折れ線と直線がオーバーラップしていた場合は追加しないでよい)
更新は一回につきO(logn)で出来て全体でO(n)回、
削除も一つの点について高々一回しかおこらないので全体でO(n)回で、
O(nlogn)時間計算量で計算できる。


最初u = nのときに、(v, w)を尺取法で全部求めているのだけれど、
他の人の解法とかを見るとどうも不要なよう。なんだけどよくわからない。


segment treeを使っても出来るらしい。

ソースコード

const int MX = 100000;
int n, a[MX], b[MX], c[MX];
map<int, int> pa, pb, pc, sum;
set<pi> s;

void ins(int i, int j){
	set<pi>::iterator it = s.lower_bound(mp(i, j));
	if(it != s.begin()){
		it--;
		if(it->second <= j) return;
	}
	s.insert(mp(i, j));
	sum[i + j]++;
}
void del(int i, int j){
	s.erase(mp(i, j));
	if(--sum[i + j] == 0) sum.erase(i + j);
}
void shaku(){
	map<int, int> cnt;
	int kind = 0, tot = 0;
	rep(i, n){
		if(cnt[a[i]]++ == 0) tot++;
		if(cnt[b[i]]++ == 0) tot++;
		if(cnt[c[i]]++ == 0) tot++;
	}
	cnt.clear();
	rep(i, n){
		if(cnt[a[i]]++ == 0) kind++;
		if(cnt[c[i]]++ == 0) kind++;
	}
	
	int i = 0, j = n;
	for(; i <= n; i++){
		while(j > 0 && cnt[c[j - 1]] > 1) cnt[c[--j]]--;
		if(tot == kind) ins(i, j);
		if(i == n) break;
		if(cnt[b[i]]++ == 0) kind++;
	}
}

int main(){
	cin >> n;
	rep(i, n) cin >> a[i];
	rep(i, n) cin >> b[i];
	rep(i, n) cin >> c[i];
	
	for(int i = n - 1; i >= 0; i--){
		pa[a[i]] = i + 1;
		pb[b[i]] = i + 1;
		pc[c[i]] = i + 1;
	}
	
	shaku();
	const int BIG = 3 * n + 100;
	sum[BIG]++;
	int ans = n + sum.begin()->first;
	
	for(int i = n - 1; i >= 0; i--){
		
		if(i + 1 == pa[a[i]]){
			
			int ib = pb.count(a[i]) ? pb[a[i]] : BIG;
			int ic = pc.count(a[i]) ? pc[a[i]] : BIG;
			
			set<pi>::iterator it = s.lower_bound(mp(ib, ic));
			vector<pi> er;
			while(it != s.begin()){
				it--;
				if(it->second >= ic) break;
				er.pb(*it);
			}
			
			if(er.empty());
			else{
				int m = er.size();
				rep(j, m) del(er[j].first, er[j].second);
				
				it = s.lower_bound(mp(er[0].first + 1, -1));
				if(it == s.end() || it->first > ib) ins(ib, er[0].second);
				
				it = s.lower_bound(mp(er[m - 1].first, -1));
				if(it != s.begin()){
					it--;
					if(it->second > ic) ins(er[m - 1].first, ic);
				}
				else ins(er[m - 1].first, ic);
			}
		}
		ans = min(ans, i + sum.begin()->first);
	}
	cout << ans << endl;
	
	return 0;
}