Codeforces 338C Divisor Tree

問題

相異なる2以上のn個の整数a[i]を全て含むようなdivisor treeのうち、頂点数が最小のものの頂点数を求めよ。
ただしdivisor treeとは以下のような木である。

  • 各頂点には整数が書かれている
  • 葉の頂点の整数は素数
  • そうでない頂点の整数は、その子の頂点の整数の積

制約条件

n≦8
a[i]≦10^12

方針

いろいろなやり方があるっぽい。
自分は適当にDPした。


dp[i][bit]を、一番上がiであり、bitで表される集合の数字を含む木の最小サイズとしてこれを更新するDP.


ただし、更新するときにもう一回小さいDPをする必要があって、
dp2[bit1][bit2]が、直接の子としてbit1を持っていて、bit2の数字をカバーするときのサイズの和の最小値。
dp[i][bit] = min(dp2[bit1][bit] + 1 + cnt(a[i] / prod[bit1]))である。
ただし、prod[i]は、a[]を、iの集合の部分だけあつめたときの積。
cnt(x)は、xを素因数分解したときの素因数の個数。


それで、最後にもう一回DPをする。
dp3[i] = iの集合を作るのに必要な森における最小サイズ
で、部分集合列挙のDPをして、(3^nじゃなくて4^nの手抜きでいい)
dp3[(1 << n) - 1] + 1が答え。


なんか計算量がすごい怪しいんだけど枝刈りされて一瞬(最悪30msくらい)で通るぽい。
cnt()はさすがにメモ化しないとダメな気がする。

ソースコード

int n;
ll a[8], dp[8][1 << 8];
ll dp2[1 << 8][1 << 8], prod[1 << 8];
map<ll, ll> memo;

ll cnt(ll n){
	if(memo.count(n)) return memo[n];
	
	ll m = n, res = 0;
	for(ll j = 2; j * j <= m; j++) if(m % j == 0){
		while(m % j == 0){
			m /= j;
			res++;
		}
	}
	if(m > 1) res++;
	return memo[n] = res;
}

int main(){
	cin >> n;
	rep(i, n) cin >> a[i];
	sort(a, a + n);
	
	rep(i, 1 << n){
		prod[i] = 1;
		rep(j, n) if(i & 1 << j){
			if(1.0 * prod[i] * a[j] > 1e18) prod[i] = 1e18;
			else prod[i] *= a[j];
		}
	}
	
	ll ans = inf;
	rep(i, n){
		rep(j, 1 << n) dp[i][j] = inf;
		dp[i][1 << i] = 1 + cnt(a[i]);
		if(dp[i][1 << i] == 2) dp[i][1 << i] = 1;
	}
	
	rep(i, n){
		
		rep(j, 1 << n) rep(k, 1 << n) dp2[j][k] = inf;
		dp2[0][0] = 0;
		
		rep(j, i){
			rep(k, 1 << j) rep(l, 1 << j+1){
				if(dp2[k][l] >= inf) continue;
				if(a[i] % prod[k | 1 << j]) continue;
				
				rep(m, 1 << j+1) if(dp[j][m] < inf){
					dp2[k | 1 << j][l | m] =
					min(dp2[k | 1 << j][l | m], dp2[k][l] + dp[j][m]);
				}
			}
		}
		rep(j, 1 << i) rep(k, 1 << i) if(dp2[j][k] < inf){
			dp[i][k | 1 << i] = min(dp[i][k | 1 << i], 1 + dp2[j][k] + cnt(a[i] / prod[j]));
		}
	}
	
	ll dp3[1 << n];
	rep(i, 1 << n) dp3[i] = inf;
	rep(i, n) rep(j, 1 << n) dp3[j] = min(dp3[j], dp[i][j]);
	ans = min(ans, dp3[(1 << n) - 1]);
	
	rep(i, 1 << n) rep(j, 1 << n) dp3[i] = min(dp3[i], dp3[j] + dp3[i ^ j]);
	ans = min(ans, dp3[(1 << n) - 1] + 1);
	
	cout << ans << endl;
	
	return 0;
}