TopCoder SRM 589 Div1 Hard FlippingBitsDiv1

問題

長さnの01からなる文字列が与えられる。
この文字列に大して以下の操作を行うことができる。

  • 任意の1文字の0,1を反転する
  • 先頭からk*m文字の0,1を反転する(ただしkは任意の自然数

操作を終えたあと、文字列の先頭からn-m文字の部分文字列と、末尾からn-m文字の部分文字列が一致する必要がある。
必要な操作の回数の最小値を求めよ。

制約条件

n≦300
m≦n

方針

先頭からの部分文字列と末尾からの文字列が重なるとき、文字列は周期mになる。
以下重なる場合について考える。


mがでかいとき、操作2で反転する仕方というのは2^(300/m)になって小さくなる。
mが小さいとき、周期ごとに文字列の0, 1を決めつけることができて、決め打ちの仕方は2^m通り。


なので、√300を境界にしてアルゴリズムを切り替えればいい。
部分文字列が重ならないときは前者のアルゴリズムを使えばいい。


前者を詳しく書くと、反転する区間を決め打ちすると、それぞれの文字が確定するので、
後は全部を一致させるために、unionごとに0, 1の多数決を取ればいい。


後者は、周期で文字を決め打ちした後は、幅m毎に反転するかしないかのDPすればよくって、
dp[何番目の(長さmの)区間か][直前で反転したか][一度でも反転したか][今後反転しない]
を更新していけばいい。


反転する区間を決めた後のコストの算出方法がちょっとハマりやすい。
基本的に01が切り替わるところでコストが+1なのだけれど、0だけの場合、最後が0で終わる場合などに注意する。

ソースコード

int s[300], n, m, l;
int p[300], dp[301][2][2][2];
int root(int x){
	if(x == p[x]) return x;
	return p[x] = root(p[x]);
}

int calc2(int bit){
	int res = 0, one[300] = {}, zero[300] = {};
	rep(i, l) if(i && (bit >> i & 1) != (bit >> i-1 & 1)) res++;
	if(bit) res++;
	if(bit && !(bit >> l-1 & 1)) res--;
	
	rep(i, n){
		int r = root(i);
		int d = s[i] ^ (bit >> (i / m) & 1);
		
		if(d == 0) zero[r]++;
		else one[r]++;
	}
	rep(i, 300) res += min(one[i], zero[i]);
	return res;
}
int calc(int bit){
	rep(i, l+1) rep(j, 2) rep(k, 2) rep(x, 2) dp[i][j][k][x] = inf;
	dp[0][0][0][0] = 0;
	
	rep(i, l) rep(prev, 2) rep(one, 2) rep(f, 2) if(dp[i][prev][one][f] < inf){
		rep(cur, 2) if(!f || cur == prev){
			int cost = dp[i][prev][one][f];
			if(prev != cur) cost++;
			if(i > 0 && !one && cur) cost++;
			
			rep(j, m){
				if(i * m + j >= n) continue;
				if((s[i * m + j] ^ cur) != (bit >> j & 1)) cost++;
			}
			dp[i+1][cur][one || cur][f] = min(dp[i+1][cur][one || cur][f], cost);
			
			if(prev && !cur) cost--;
			dp[i+1][cur][one || cur][1] = min(dp[i+1][cur][one || cur][1], cost);
		}
	}
	int res = inf;
	rep(i, 2) rep(j, 2) rep(k, 2) res = min(res, dp[l][i][j][k]);
	return res;
}

class FlippingBitsDiv1 {
	public:
	int getmin(vector <string> S, int M) {
		string t;
		rep(i, S.size()) t += S[i];
		n = t.size();
		m = M;
		l = (n + m - 1) / m;
		rep(i, n) s[i] = t[i] - '0', p[i] = i;
		
		rep(i, n - m){
			int a = root(i), b = root(m + i);
			if(a != b) p[b] = a;
		}
		
		if(m >= 18 || n < 2 * m){
			int ans = inf;
			rep(i, 1 << l) ans = min(ans, calc2(i));
			return ans;
		}
		int ans = inf;
		rep(i, 1 << m) ans = min(ans, calc(i));
		return ans;
	}
};