TopCoder SRM 622 Div1 Hard

問題

aまたはbからなる文字列が与えられる。
このとき、文字列に2回以上、重ならないで出現する空でない部分文字列の個数を求めよ。
文字列が同じときは一つと数える。

制約条件

文字列は乱数により生成される。
n≦10^5

方針

文字列のSuffix Array, LCPを作る。
重なってもよいので2度以上出現するユニークな部分文字列が何種類あるかは、
Σmax(LCP[i] - LCP[i - 1])を取ればよい。


たとえば文字列がこんな具合だとする。

i: 1 lcp: 0 si: 0  aaaabb
i: 2 lcp: 3 si: 1  aaabb <-ここでa, aa, aaaが二度出てきたということがわかる
i: 3 lcp: 2 si: 2  aabb
i: 4 lcp: 1 si: 3  abb
i: 5 lcp: 0 si: 5  b
i: 6 lcp: 1 si: 4  bb <-ここでbが二度出てきたことがわかる

LCPを見れば、2度以上出現する文字列が全部わかるが、
そのうちユニークなものを求めたいので、各文字について、最初に出現したところを見てやればよい。
それはいつ現れるかというと、LCPが最初に増えたところなので、LCPが増えたときに、増えた分だけ足せばユニークな(2度以上出現する)文字列の個数になる。


それで、この問題では重なってはいけないという制約があるので、
2度出現する文字が出てきたときに、重ならないのがどこまでかを考えなければならない。
上の表でi = 2のときにa, aa, aaaが出現する。(二度出てきたとわかる)


これを全部重複しないかいちいち数えると間に合わない。


出現した文字列のうち、長さLのものが重複せずに出現可能ならば、l≦Lであるようなl文字の文字列も重複せずに出現可能。
したがって、重複しないものの上限を二分探索によって求めることを考える。


長さLのものが重複せずに出現可能かどうかはどう判定すればいいかというと、LCPのテーブルを二分探索して、
min(LCP[x], LCP[x+1], ..., LCP[i], LCP[i + 1], ..., LCP[y-1], LCP[y])がL以上であるような最小のxと最大のyをそれぞれ求め、


この区間[x, y]内のsiの最小値と最大値を見れば、この長さLの文字列の出現位置の一番左の位置と一番右の位置がわかるため、左の位置 + L ≦右の位置が成り立っているかどうかで、重複なく出現が可能であるか判定することができる。


LCP上の二分探索および区間内の最小最大値を求める処理は、
適切なデータ構造(segment treeとかsparse tableとか)を使ってO(logn)くらいの時間計算量で求めてやる。

ソースコード

class StringsNightmareAgain {
	public:
	long long UniqueSubstrings(int a, int b, int c, int d, int n) {
		
		memset(lcp, 0, sizeof(lcp));
		
		string s(n, 'a');
		rep(i, a){
			b = ((ll)b * c + d) % n;
			s[b] = 'b';
		}
		sprintf(t, "%s", s.c_str());
		buildSA(t);
		buildLCP(t);
		
		ll ans = 0;
		int *rmq = buildRMQ(lcp, n + 1);
		int *prmq = buildRMQ(si, n + 1); rep(i, n + 1) si[i] *= -1;
		int *mrmq = buildRMQ(si, n + 1); rep(i, n + 1) si[i] *= -1;
		
		rep(i, n + 1) cerr << "i: "<<i<<" lcp: "<<lcp[i]<<" si: "<<si[i]<<"  "<< t + si[i] << endl;
		
		for(int i = 2; i <= n; i++){
			if(lcp[i - 1] >= lcp[i]) continue;
			
			int l = lcp[i - 1], r = lcp[i] + 1, m;
			while(l + 1 < r){
				m = (l + r) / 2;
				
				int x, y, lo, hi, mid;
				
				lo = 0; hi = i;
				while(lo + 1 < hi){
					mid = (lo + hi) / 2;
					if(minimum(mid + 1, i, rmq, n + 1) >= m) hi = mid;
					else lo = mid;
				}
				x = hi;
				
				
				lo = i; hi = n + 1;
				while(lo + 1 < hi){
					mid = (lo + hi) / 2;
					if(minimum(i + 1, mid, rmq, n + 1) >= m) lo = mid;
					else hi = mid;
				}
				y = lo;
				int L = minimum(x, y, prmq, n + 1);
				int R =-minimum(x, y, mrmq, n + 1);
				if(L + m <= R) l = m;
				else r = m;
			}
			ans += l - lcp[i - 1];
		}
		return ans;
	}