FFT-NTT

捲積方法 OAO

本篇的重點應該是放在理解與使用與競程相關的捲積,其中會利用到FFT或NTT加速
這邊先定義一下捲積是什麼

(ab)x=i+j=xaibj

實際上就等價於我們常見的多項式乘法
naive的做法是O(n2),顯然不夠令人滿意

1
2
3
4
5
for(int i = 0; i < A.size(); i++) {
	for(int j = 0; j < B.size(); j++) {
		res[i+j] += A[i] * B[j];
	}
}

DFT

先假設我們有兩個多項式
A(x)=aixi,B(x)=bixiC(x)=A(x)B(x)
除了上述利用分配律乘開以外
因為C的次數已經可以確定
我們也可以在A(x)B(x)找出n個相異的點,相乘之後再利用插值法代入得到C
其中n=degC+1
也就是

[1x0x02x0n11x1x12x1n11xn1xn12xn1n1][a0a1an1]=[A(x0)A(x1)A(xn1)]

[1x0x02x0n11x1x12x1n11xn1xn12xn1n1][b0b1bn1]=[B(x0)B(x1)B(xn1)]

[1x0x02x0n11x1x12x1n11xn1xn12xn1n1]1[A(x0)B(x0)A(x1)B(x1)A(xn1)B(xn1)]=[c0c1cn1]

最後一步直接高斯消去或用拉格朗日/牛頓插值法可以做到O(n2)

上述步驟稱為DFT(對序列的版本叫離散傅立葉變換,與使用積分的連續傅立葉變換相對)和IDFT
但是這樣根本沒有改進多少複雜度啊?
邁向快速傅立葉變換的鑰匙是利用複數,取特定的某些x讓我們能夠分治

Root of Unity

首先先來介紹單位根ω是使得

ωn=10i<j<n,ωiωj

的數

複習一下歐拉公式eix=cosx+isinx
習慣上可以取ωn=e2πin(下標是表示n是最小的i使ωi=1,或者說ord(ωn)=n)

404的啦QQ

推薦觀賞3B1B系列
https://youtu.be/v0YEaeIClKY
https://youtu.be/mvmuCPvRoWQ

引理們

Lemma a.

ωdndk=(e2πidn)dk=(e2πin)k=ωnk

Lemma b.

ωnn2=ω2=eiπ=1

Cooley-Tukey FFT algorithm

先假設n是2的冪次,然後下面提到的i都只是index
將DFT中的xi取值為ωni,可以知道我們要算的就是對i[0,n1]
yi=j=0n1aj(ωni)j
把右式的奇數項和偶數項分開處理(這邊是原理的精華)
yi=j=0n1aj(ωni)j=j=0n21a2j(ωni)2j+j=0n21a2j+1(ωni)2j+1=j=0n21a2j(ωn2i)j+ωnij=0n21a2j+1(ωn2i)j=Feven(i)+ωniFodd(i)

其中Feven,Fodd分別是以奇數和偶數項FFT得到的東西,可以遞迴求解
雖然以n2的長度遞迴只能得到i[0,n21]的答案
不過FevenFodd都有週期n2,再由Lemma b.可以簡化成

for 0i<n2,{yi=Feven(i)+ωniFodd(i)yi+n2=Feven(i)ωniFodd(i)

時間複雜度有T(n)=2T(n/2)+O(n),由主定理可知T(n)=O(nlogn)
要將FFT一言以概之,大概就是利用分治法將多項式轉換成點值表示吧
附上遞迴版的參考程式碼,雖然迭代版通常效率較好不過遞迴版有助於理解

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
const double PI = acos(-1);
typedef complex<double> cd;
vector<cd> FFT(const vector<cd> &F) { // assume F.size() == 2^k
	if(F.size() == 1) return F; // base case (important)
	vector<cd> rec[2], ans;
	for(int i = 0; i < F.size(); i++) rec[i&1].push_back(F[i]);
	rec[0] = FFT(rec[0]);
	rec[1] = FFT(rec[1]);
	double theta = -2*PI / F.size();
	cd now = 1, omega(cos(theta), sin(theta));
	ans.resize(F.size());
	for(int i = 0; i < F.size()/2; i++) {
		ans[i] = rec[0] + now * rec[1];
		ans[i+F.size()/2] = rec[0] - now * rec[1];
	}
	return ans;
}

Inverse-FFT

那麼要怎麼做IFFT(傅立葉變換的逆變換),也就是把點值表示轉換回係數呢?
FFT可以寫成矩陣的形式,也就是

[11111ωω2ωn11ωn1(ωn1)2(ωn1)n1][c0c1cn1]=[C(1)C(ω)C(ωn1)]
左項有一個范德蒙矩陣V=[ωij] (0-base)
事實上其反矩陣就是V=[1nωij]

說明:
[VV]i,j=k=0n1Vi,kVk,j=1nk=0n1ωk(ij)
i=j時顯然為1
ij利用等比級數公式可以知道總和為0
故相乘的結果是單位矩陣

可以發現我們只需要把FFT的ω改成倒數,最後再除上n就是IFFT所需要的
因為FFT和IFFT的相似性,我們可以將程式碼整合如下

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
const double PI = acos(-1);
typedef complex<double> cd;
vector<cd> FFT(const vector<cd> &F, bool inv) { // assume F.size() == 2^k
	if(F.size() == 1) return F; // base case (important)
	vector<cd> rec[2];
	for(int i = 0; i < F.size(); i++) rec[i&1].push_back(F[i]);
	rec[0] = FFT(rec[0],inv);
	rec[1] = FFT(rec[1],inv);
	double theta = (inv ? 1 : -1) * 2 * PI / F.size();
	cd now = 1, omega(cos(theta), sin(theta));
	vector<cd> ans(F.size());
	for(int i = 0; i < F.size()/2; i++) {
		ans[i] = rec[0][i] + now * rec[1][i];
		ans[i+F.size()/2] = rec[0][i] - now * rec[1][i];
		now *= omega;
	}
	if(inv) for(int i = 0; i < ans.size(); i++) ans[i] /= 2;
	return ans;
}

Convolution

有了FFT和IFFT兩個工具,我們要做捲積就很簡單了

  1. 確定兩個多項式相乘的次數,並且選擇一個足夠大的n=2k(後面可以補0)
  2. 利用Cooley-Tukey演算法求出A,B的傅立葉變換A^,B^
  3. A^,B^在對應位置兩兩相乘得到C^(可能叫Hadamard Product吧)
  4. 再利用Cooley-Tukey演算法求出C
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
vector<cd> A{1,3,4};
vector<cd> B{1,2,5};
signed main() {
    int n = 1<<__lg(A.size()+B.size())+1;
    A.resize(n);
    B.resize(n);
    A = FFT(A,0);
    B = FFT(B,0);
    vector<cd> C(n);
    for(int i = 0; i < n; i++) C[i] = A[i]*B[i];
    C = FFT(C,1);
    for(int i = 0; i < n; i++) cout << C[i].real() << " \n"[i==n-1];
}

Iterative Version

迭代的版本不但簡單執行時間又快,值得記一下

觀察遞迴的情況
可以看到我們每次都是將一個序列的偶數項放前面做,奇數項放後面做再合併
這可以想成將最低位的0/1移到最高位,例如
100010111000101111101111111010101000101010
重複執行了把最低位移到最高位的動作k=log2n次之後
原本放在i的位置的數字的index最後會被放到j的地方,其中jik位二進位數的反轉
也就是說我們可以一開始就把所有數字放到他在遞迴樹中對應的位置,再一層一層往上合併

404的啦QQ

那要拿哪些合併呢?其實每個相鄰的兩塊的相同位置對應的就是FevenFodd,組合算出yi之後要填的地方也是那兩格
剩下的就是看code理解了吧…OwO?

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
typedef complex<double> cd;
void FFT(cd F[], int n, bool inv) { // in-place FFT, also assume n = 2^k
	for(int i = 0, j = 0; i < n; i++) {
		if(i < j) swap(F[i], F[j]);
		// magic! (maintain j to be the bit reverse of i)
		for(int k = n>>1; (j^=k) < k; k>>=1);
	}
	for(int step = 1; step < n; step <<= 1) {
		double theta = (inv ? 1 : -1) * PI / step;
		cd omega(cos(theta), sin(theta));
		for(int i = 0; i < n; i += step*2) {
			cd now(1,0);
			for(int j = 0; j < step; j++) {
				cd a = F[i+j];
				cd b = F[i+j+step] * now;
				F[i+j] = a+b;
				F[i+j+step] = a-b;
				now *= omega;
			}
		}
	}
	if(inv) for(int i = 0; i < n; i++) F[i] /= n;
}

NTT

注意到我們可以實行分治的關鍵就是存在一個ω使得

ωn=10i<j<n,ωiωj

現在我們想要在模一個質數p下做類似的事

費馬小定理表明

(a,p)=1,aφ(p)1(modp)

如果有原根g使得

0i<j<φ(p),gigj(modp)

那麼ω的選擇就很簡單了,也就是ωngφ(p)n
容易驗證ωn滿足上面的性質

這樣做必須滿足n|φ(p),而若使用Cooley-Tukey演算法的話n會是2的冪次
也就是說若φ(p)=p1=t2k,其中t是奇數
那對這個p來說可行的n的範圍最多就是2k
這也是為什麼NTT的模數常常都是那些數字的原因
因為p1必須在二進位下有很多個後綴0

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
const int64_t MOD = 998244353, G = 3;
int64_t modpow(int64_t e, int64_t p, int64_t m) {
	int64_t r = 1;
	for(; p; p>>=1) {
		if(p&1) r = r*e%m;
		e = e*e%m;
	}
	return r;
}
void NTT(int64_t F[], int n, bool inv) { // assume n = 2^k!
	for(int i = 0, j = 0; i < n; i++) {
		if(i < j) swap(F[i], F[j]);
		for(int k = n>>1; (j^=k) < k; k>>=1);
	}
	for(int step = 1; step < n; step <<= 1) {
		//may preprocess to boost
		int64_t omega = modpow(G, (MOD-1) / (step*2), MOD);
		if(inv) omega = modpow(omega, MOD-2, MOD);
		for(int i = 0; i < n; i += step*2) {
			int64_t now = 1;
			for(int j = 0; j < step; j++) {
				cd a = F[i+j];
				cd b = F[i+j+step] * now % MOD;
				// reduce the use of % operator
				F[i+j] = (a+b < MOD ? a+b : a+b-MOD);
				F[i+j+step] = (a-b<0 ? a-b+MOD : a-b);
				now = now*omega%MOD;
			}
		}
	}
	if(inv) {
		int64_t invn = modpow(n, MOD-2, MOD);
		for(int i = 0; i < n; i++) F[i] = F[i]*invn%MOD;
	}
}

→NTT模數表←

中國剩餘?

一個模數合不合適取決於最後答案的大小
兩個值域c、長度n的多項式相乘,得出來的乘積的值域最多會是nc2
如果不會超過模數的話就可以直接使用
但如果會超過怎麼辦?
挑選更大的模數沒什麼用,因為相乘起來可能就超過long long了
這時我們就必須做多次NTT再用中國剩餘定理合併

如果真實的答案不是指數或階乘那種直接爆炸的數值
甚至還可以用來對任意數字取模(?)

End

FFT與NTT的利用其實滿少的,大部分不是大數乘法就是生成函數,以後有時間再放一篇講好了