読者です 読者をやめる 読者になる 読者になる

競技プログラミングをするんだよ

ICPC国内予選突破を目標に一日一問題以上解いていきます。

m項間漸化式の高速なアルゴリズム

実装メモ

3/22:テンプレート追加しました。

m項間漸化式のk項目を高速に求める方法。通称「きたまさ法」と呼ばれるアルゴリズムについての(簡単な)解説。
考案者が蟻本作者の方でその名前からとられたみたいです。計算量はO(m^2 \log k)

入力

第k項:x_k = c_1x_{k-1}+c_2x_{k-2}+...c_mx_{k-m}=\sum_{i=1}^mc_ix_{k-i}
初項 :x_1...x_m
で定義されるm項間漸化式が与えられる。

出力

k項目の値を出力せよ。

このタイプの問題に対して強力なアルゴリズムです。
一般的なDPがO(k),演算子法を使ってコンパニオン行列を使った解法がO(m^3 \log k)で実装できるが、それよりも高速。

証明のような解法のようなもの

前提

1.k項目が定数sを用いて次のように次のような線形結合で表せる。
x_k=s_1x_{a+1}+s_2x_{a+2}...s_mx_{a+m}=\sum_{i=1}^ms_ix_{a+i} (n > a+m)
2.1が成り立つとき、k+n項目は次のように表せる。
x_{k+n}=s_1x_{a+n+1}+s_2x_{a+n+2}...s_mx_{a+n+m}=\sum_{i=1}^ms_ix_{a+n+i}
1の式の各項にnを足しただけなので直感的に分かると思います。

これを用いて解きます。

解法

1.u項目を初項の線形結合で表す。
x_u=s_1x_{1}+s_2x_{2}...s_mx_{m}=\sum_{i=1}^ms_ix_{i}
2.k+k項目の漸化式を前提2を使い、n=uとして次のように表す。
x_{u+u}=s_1x_{u+1}+s_2x_{u+2}...s_mx_{u+m}=\sum_{i=1}^ms_ix_{u+i}
3.2の式の各項のx_u+iについて、前提2の式をn=iとして次のように展開する。
x_{u+i}=s_1x_{i+1}+s_2x_{i+2}...s_mx_{i+m}=\sum_{j=1}^ms_jx_{i+j}
4.3の式を2の式に代入する。
x_{u+u}=x_{2u}=\sum_{i=1}^ms_ix_{u+i}=\sum_{i=1}^m\sum_{j=1}^ms_is_jx_{i+j}
5.x_{2u}=t_1x_1+t_2x_2+...t_mx_m+...t_{2m}x_{2m}とすると、
t_1...t_{2m}をO(m^2)で求められ、x_1からx_{2m}までの線形結合として表せる
6.5の式をx_1からx_{m}の線形結合の形に直し、x_{4u}を同様の手順で求める。

これを繰り返すことにより、
x_1,x_2,x_4,,,x_{2^{\log k}}を初項の線形結合としてO(m^2 \log k)で求めることができる。
前提2を使えば,x_u=\sum_{i=1}^ms_ix_i,x_v=\sum_{i=1}^mt_ix_iのとき、
x_{u+v}=\sum_{i=1}^m\sum_{j=1}^ms_it_jx_{i+j}として求められるので、
kをバイナリー展開し、ビットが立っているところを上の式を使って足し合わせてやればk項目を初項の線形結合で表すことができ
初項の線形結合であれば値を計算することができる。

TDPCの「フィボナッチ」がこれを用いないと解けない問題となっている。
また、このアルゴリズムはコンパニオン行列の時と同様演算子法を応用したものであるため演算子群が半環であれば足し算、掛け算でなくとも使用できる。

テンプレート

main,XはMボナッチ数列のN項目のMOD1,000,000,007を求めるためのもの

#include<iostream>
#include<fstream>
#include<sstream>
#include<string>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<stack>
#include<queue>
#include<set>
#include<map>
#include<vector>
#include<list>
#include<algorithm>
#include<utility>
#include<complex>

using namespace std;

#define reE(i,a,b) for(auto (i)=(a);(i)<=(b);(i)++)
#define rE(i,b) reE(i,0,b)
#define reT(i,a,b) for(auto (i)=(a);(i)<(b);(i)++)
#define rT(i,b) reT(i,0,b)
#define rep(i,a,b) reE(i,a,b);
#define rev(i,a,b) for(auto (i)=(b)-1;(i)>=(a);(i)--)
#define itr(i,b) for(auto (i)=(b).begin();(i)!=(b).end();++(i))
#define LL long long
#define all(b) (b).begin(),(b).end()


/*
使い方
M項間漸化式のN項目を計算する。Nの最小は1。
a[N]=sum(c[i]*a[N-M+i-1])の形。
Mrが本体、Xは使わなくてもできる問題はある。(MODを使わないint,LLなど)
Xは半環を満たすものならなんでもよい、+,*のオーバーロードにMODをねじこむ。
コンストラクタに初項、係数、M,*の単位元、+の単位元の順で引数を与える。
あとはcalcにNを与えるだけ。
*/
#define MAX_LOGN 32
template <class T>
struct Mr{
	vector<T> first;
	vector<T> C;
	vector<vector<T>> bin;
	T zero,one;
	int M;
	//n(1,,,2M)をn(1,,,M)に修正、O(M^2)
	void form(vector<T> &n){
		rev(i, M + 1, 2 * M + 1){
			reE(j, 1, M)n[i - j] = (n[i - j] + (C[M - j] * n[i]));
			n[i] = zero;
		}
	}
	//lとrを足し合わせる、O(M^2)
	void add(vector<T> &l, vector<T> &r, vector<T> &ans){
		reE(i, 1, 2 * M)ans[i] = zero;
		reE(i, 1, M)reE(j, 1, M)ans[i + j] = (ans[i + j] + (l[i] * r[j]));
		form(ans);
	}
	//初期化、O(M*MAX_LOGN)
	Mr(const vector<T>& f,const vector<T>& c,int m,T e1,T e0){
		M = m;
		first.reserve(M + 1);C.reserve(M);
		zero = e0, one = e1;
		first.push_back(zero); 
		rT(i, M){ first.push_back(f[i]); C.push_back(c[i]); }
		bin.resize(MAX_LOGN);
		rT(i, MAX_LOGN)bin[i].resize(2*M+1);
		rE(i, 2*M)bin[0][i] = zero; bin[0][1] = one;
		reT(i,1, MAX_LOGN){
			add(bin[i - 1], bin[i - 1], bin[i]);
		}
	}
	//N項目の計算、戻り値がTの形であることに注意、O(M^2*logN)
	T calc(LL n){
		n--;
		vector<T> tmp,result = bin[0];
		for (int b = 0; n; b++,n>>=1)
			if (1 & n){ tmp = result; add(tmp, bin[b], result); }
		T ans = zero;
		reE(i, 1, M)ans = ans + (result[i] * first[i]);
		return ans;
	}
};
//テンプレート、デフォルトコンストラクタのオーバーロードを忘れない
#define MOD 1000000007
struct X{
	LL val;
	X(LL  v){ val = v; }
	X(){ val = 0; }
	LL operator=(const X &another){ return val = another.val; }
	LL operator*(const X &another)const{ return (val*another.val)%MOD; }
	LL operator+(const X &another)const{ return (val+another.val)%MOD; }
};


int main(void){
	LL n;
	int M;

	vector<X> A,B;
	cin >> M >> n;
	A.reserve(M);
	B.reserve(M);
	rT(i, M){ A.push_back(X(1)); }
	rT(i, M){ B.push_back(X(1)); }
	Mr<X> mr(A,B,M,X(1),X(0));
	cout << mr.calc(n).val << endl;
	return(0);
}