TopCoder SRM 596 Div1 Hard SparseFactorial

問題

nのsparce factorialを、
f(n) = n * (n - 1) * (n - 4) * (n - 9) * ……と定義する。


lo以上hi以下の整数nで、f(n)がdivisorの倍数であるものはいくつあるか、求めよ。

制約条件

lo, hi≦10^12
2≦divisor≦10^6

方針

editorialみた。


f(n)がmで割り切れるとき、f(n + m)もmで割り切れる。
したがって、n % m = iとするとiの値0, 1, 2, ..., m - 1ごとに、
それぞれf(n)がmで割り切れる最小のn(これをtable[i])とするを見つけてやればよい。


m1, m2ごとにtableがわかっていたら、m1 * m2に対するtableは簡単に作れる。
table12[i] = max(table1[i % m1], table2[i % m2])とすればいいだけ。


ではdivisorをp1^e1 * p2^e2 * ……と素因数分解したときに、
mi = pi^eiとして、それぞれにテーブルを作ってやればよさそう。


tableは愚直に作ると、
iについてm回ループを回して、そこからkのループをp * e回まわすことになる。
(f(n) ≡ Π(i - k^2) と考えて)


なんでkがp * e回で十分かと言えば、kの値はmod pで考えると最大pでループし、
一回のループで一つ以上pで割れる回数が増えるから。(割れるなら)


でもこれだとTLE.(最悪m = pだとm^2, divisor^2がかかってしまうから)
で、ループの順序を入れ替えると上手く行く。


先にkを回す。
すると、iはi % k^2 == 0であるiからはじめてpずつ増やしていけばいいため、ループの回数が1/p倍になる。
すなわち、ループの回数は全体でm * p * e / p = m * e回になる。


これだと全てのmについてテーブルを作ってもO(divisor log(divisor))になる。
テーブルを合成するところでもO(divisor log(divisor))の計算量がかかるので、
全体でO(divisor log(divisor))の時間計算量で計算できる。

ソースコード

vector<ll> calc(int p, int e, int m){
	
	vector<ll> cnt(m); //cnt[i]: i - k^2にpが何回出てきたか
	vector<ll> table(m, inf); //n >= table[i] iif n = i mod mでp^eで割り切れる
	
	rep(k, p * e){
		
		ll k2 = (ll)k * k;
		
		for(int i = k2 % p; i < m; i += p) if(cnt[i] < e){
			
			int t = i - k2 % m;
			
			if(t < 0) t += m;
			if(t == 0) cnt[i] += e;
			else for(; t % p == 0; t /= p) cnt[i]++;
			
			if(cnt[i] >= e) table[i] = k2 + 1;
		}
	}
	return table;
}

class SparseFactorial {
	public:
	long long getCount(long long lo, long long hi, long long divisor) {
		
		ll ans = 0;
		int prev = 1;
		vector<ll> prev_table(1, 0);
		
		for(int p = 2; p <= divisor; p++) if(divisor % p == 0){
			
			int e = 0;
			int cur = 1;
			
			while(divisor % p == 0){
				divisor /= p;
				e++;
				cur *= p;
			}
			
			vector<ll> cur_table = calc(p, e, cur);
			vector<ll> next_table(cur * prev);
			
			rep(i, next_table.size()) next_table[i] = max(cur_table[i % cur], prev_table[i % prev]);
			
			swap(prev_table, next_table);
			prev *= cur;
		}
		
		rep(i, prev){
			ll l = max(lo, prev_table[i]), r = hi;
			l = (l - i + prev - 1) / prev * prev + i;
			if(l > r) continue;
			
			ans += (r - l) / prev + 1;
		}
		return ans;
	}
};