Codeforces 478(#273 Div2only) E. Wavy numbers

問題

wavy numberとは、数字を10進法で書いたときの数字をx1,x2,x3,...,xnとすると
x1<x2>x3<x4…が成り立つまたは、x1>x2<x3>x4…が成り立つ数のことを言う。

Nの倍数のwavy numberでK番目に小さいものを求めよ。
答えが存在しない場合や10^14より大きくなる場合は-1を出力せよ。

制約条件

N, K≦10^14

方針

場合分け。
Nが10^7以上のときは候補の数が10^7個程度しかないので全探索できる。


それ以外の場合は、K番目に小さい数を求める問題の典型手法である、
区間[l, r)内の条件を満たすものの個数を数えていく手法を使う。


区間内のwavy numberの個数がK個以上だったらその区間内に答えがあるのでその区間を詳しく調べ、
そうでなければ、区間内のwavy numberの個数をKから引いて、次の区間を見ていく。
区間を[x * 10^7, (x + 1) * 10^7)というように取れば、区間が重なりなく、全ての範囲をカバーしていて、小さい順に並んでいるため、これで正しい答えがわかる。


まずx = 0, すなわち[0, 10^7)の区間は愚直に調べてよい。
[x * 10^7, (x + 1) * 10^7)でx≧1の区間は、


まずxがwavy numberでなければ0個。そうでないときは、
下7桁は、「xより小さい数で始まり、最初に増加するwavy number」か、
「xより大きい数で始まり、最初に減少するwavy number」であって、
Nでmodを取ったときにN - x * 10^7 % Nになっている数であればよい。


最初に7桁のwavy numberを全列挙しておき、「増加するか減少するか」と、modごとにわけて配列に入れておけば、二分探索で上の条件を満たすものの個数が簡単にわかる。


leading zeroを考えていなかったのでその部分をアドホックに処理してしまっているが、
そこを含めて生成しておけばもっとコードが簡潔になる…orz

ソースコード

ll N, K;
vi wav[2][10000000];

void gen(int pos, int sum, int last, bool inc){
	
	if(pos && !(pos == 1 && inc)){
		int id = inc ^ (pos % 2);
		wav[1 - id][sum % N].pb(sum);
	}
	if(pos == 7) return;
	
	rep(i, 10) if(inc && last < i || !inc && last > i){
		if(i == 0 && pos == 0) continue;
		gen(pos + 1, sum * 10 + i, i, !inc);
	}
}
inline bool iswav(ll x){
	int n = 0, d[20];
	for(; x; x /= 10) d[n++] = x % 10;
	for(int i = 1; i < n - 1; i++){
		if(!(d[i - 1] < d[i] && d[i] > d[i + 1]) && !(d[i - 1] > d[i] && d[i] < d[i + 1])) return 0;
	}
	if(n == 2 && d[0] == d[1]) return 0;
	return 1;
}

int main(){
	cin >> N >> K;
	gen(0, 0, -1, 1);
	gen(0, 0, 10, 0);
	
	if(N >= 1e7){
		for(ll x = N; x <= 1e14; x += N) if(iswav(x) && --K == 0){
			cout << x << endl;
			return 0;
		}
		cout << -1 << endl;
		return 0;
	}
	rep(k, 2) rep(i, 1e7) sort(all(wav[k][i]));
	
	if(K <= wav[0][0].size() + wav[1][0].size()){
		vi v(wav[0][0].size() + wav[1][0].size());
		merge(all(wav[0][0]), all(wav[1][0]), v.begin());
		cout << v[K - 1] << endl;
		return 0;
	}
	K -= wav[0][0].size() + wav[1][0].size();
	
	for(int x = 1; x <= 1e7; x++) if(iswav(x)){
		
		int L = (x % 10 + 1) * (ll)1e6;
		int R = (x % 10) * (ll)1e6;
		
		int mod = (N - x * (ll)1e7 % N) % N;
		int sml = lower_bound(all(wav[1][mod]), R) - lower_bound(all(wav[1][mod]), (ll)1e6);
		int big = wav[0][mod].end() - lower_bound(all(wav[0][mod]), L);
		
		//leading zero
		sml += lower_bound(all(wav[0][mod]), (ll)1e6) - lower_bound(all(wav[0][mod]), (ll)1e5);
		
		if(x >= 10){
			if(x / 10 % 10 < x % 10) big = 0;
			else sml = 0;
		}
		
		if(sml + big < K){
			K -= sml + big;
			continue;
		}
		
		vi v;
		vi::iterator it, it2;
		
		if(sml){
			it = lower_bound(all(wav[1][mod]), (ll)1e6);
			for(; it < wav[1][mod].end() && *it < R; ) v.pb(*it++);
			//leading zero
			it = lower_bound(all(wav[0][mod]), (ll)1e5);
			for(; it < wav[0][mod].end() && *it < 1e6; ) v.pb(*it++);
		}
		if(big){
			it = lower_bound(all(wav[0][mod]), L);
			for(; it != wav[0][mod].end(); ) v.pb(*it++);
		}
		sort(all(v));
		cout << x * (ll)1e7 + v[K - 1] << endl;
		return 0;
	}
	cout << -1 << endl;
	
	return 0;
}