An Introduction to the Fast Fourier Transform (FFT) and the Number-Theoretic Transform (NTT): Principles, Algorithms, and C++ Implementation

The Fourier transform maps a function \(f(t)\) in the time domain to a function \(F(\omega)\) in the frequency domain. Intuitively, it describes the signal as a weighted superposition of sinusoids of different frequencies, with \(F(\omega)\) giving the (complex) weight of frequency \(\omega\).

\[ F(\omega)=\mathcal{F}[f(t)]=\int\limits_{-\infty}^\infty f(t)e^{-i\omega t}\,dt\]

The inverse Fourier transform maps \(F(\omega)\) back to the time-domain function \(f(t)\). Here \(f(t)\) is called the original function and \(F(\omega)\) the image function; together they form a Fourier transform pair.

\[ f(t)=\mathcal{F}^{-1}[F(\omega)]=\frac 1 {2\pi}\int\limits_{-\infty}^\infty F(\omega)e^{i\omega t}\,d\omega\]

The Discrete Fourier Transform (DFT) is the version of the Fourier transform in which both the time domain and the frequency domain are discrete, with N samples each:

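\[ X_k=\sum_{n=0}^{N-1}x_ne^{\frac{2\pi i}{N} kn},\quad k=0,\ldots,N-1\]

This sign convention follows Introduction to Algorithms and matches the root of unity \(w_n=e^{\frac{2\pi}{n}i}\) used below; many texts put the minus sign in the forward transform instead.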

Similarly, there is the Inverse Discrete Fourier Transform (IDFT):

\[ x_n=\frac 1 N\sum_{k=0}^{N-1}X_ke^{-\frac{2\pi i}{N} kn},n=0,...,N-1\]
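A direct \(O(N^2)\) implementation of the two formulas is a useful reference for checking a faster FFT. This is a minimal sketch: the names naive_dft and naive_idft are illustrative, and it assumes the forward transform \(X_k=\sum_n x_n e^{2\pi i kn/N}\) (positive exponent) and the inverse with the negative exponent and the factor \(1/N\).

```cpp
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

using cd = std::complex<double>;

// Forward DFT by direct summation, O(N^2):
// X_k = sum_n x_n * e^{+2*pi*i*k*n/N}  (positive exponent, matching w_n = e^{2*pi*i/n}).
std::vector<cd> naive_dft(const std::vector<cd>& x) {
    const std::size_t N = x.size();
    const double pi = std::acos(-1.0);
    std::vector<cd> X(N);
    for (std::size_t k = 0; k < N; ++k)
        for (std::size_t n = 0; n < N; ++n)
            X[k] += x[n] * std::polar(1.0, 2.0 * pi * double(k * n) / double(N));
    return X;
}

// Inverse DFT by direct summation, O(N^2):
// x_n = (1/N) * sum_k X_k * e^{-2*pi*i*k*n/N}.
std::vector<cd> naive_idft(const std::vector<cd>& X) {
    const std::size_t N = X.size();
    const double pi = std::acos(-1.0);
    std::vector<cd> x(N);
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t k = 0; k < N; ++k)
            x[n] += X[k] * std::polar(1.0, -2.0 * pi * double(k * n) / double(N));
        x[n] /= double(N);
    }
    return x;
}
```

For x = (1, 2, 3, 4), naive_dft gives \(X_0 = 10\) and \(X_1 = 1 + 2i - 3 - 4i = -2 - 2i\) (up to rounding), and naive_idft maps the result back to the input.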

Fast Fourier Transform (FFT)

The Fast Fourier Transform (FFT) uses divide and conquer to split a DFT into smaller DFTs, computing it in \(O(n\log n)\) time instead of the \(O(n^2)\) of direct evaluation. It is commonly used to speed up convolution and polynomial multiplication.

Polynomial Multiplication

Let \(C(x)=A(x)B(x)=\sum_{i=0}^{n-1}\sum_{j=0}^{i}a_jb_{i-j}x^i\), where \(\deg A+\deg B=n-1\) (coefficients beyond a polynomial's degree are taken as 0). Computing these coefficients directly takes \(O(n^2)\) time. Suppose instead that A and B are given in point-value form at the same n points:

\[ A(x):\{(x_0,y_0),(x_1,y_1),\ldots,(x_{n-1},y_{n-1})\}\\ B(x):\{(x_0,y_0'),(x_1,y_1'),\ldots,(x_{n-1},y_{n-1}')\}\\ C(x):\{(x_0,y_0y_0'),(x_1,y_1y_1'),\ldots,(x_{n-1},y_{n-1}y_{n-1}')\}\]

Then the point values of C are just the products \(y_iy_i'\), computed in \(O(n)\) time. Since \(\deg C=n-1\), these n points determine C.
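As a baseline, the direct \(O(n^2)\) product and the point-value identity can be checked side by side. In this sketch (function names are illustrative), multiply_naive computes the coefficients of C directly, and eval_poly evaluates a polynomial so that \(C(x_i)=A(x_i)B(x_i)\) can be confirmed at n distinct points.

```cpp
#include <cstddef>
#include <vector>

// Coefficients are stored lowest degree first: a[i] is the coefficient of x^i.

// Direct O(n^2) product: c_{i+j} += a_i * b_j, i.e. c_i = sum_j a_j * b_{i-j}.
std::vector<long long> multiply_naive(const std::vector<long long>& a,
                                      const std::vector<long long>& b) {
    if (a.empty() || b.empty()) return {};
    std::vector<long long> c(a.size() + b.size() - 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b.size(); ++j)
            c[i + j] += a[i] * b[j];
    return c;
}

// Value of a polynomial at x, accumulating powers of x term by term.
long long eval_poly(const std::vector<long long>& a, long long x) {
    long long result = 0, power = 1;
    for (long long coef : a) {
        result += coef * power;
        power *= x;
    }
    return result;
}
```

For A = 1 + 2x and B = 3 + 4x, multiply_naive returns {3, 10, 8}, and at x = 2 the point values agree: A(2)B(2) = 5 · 11 = 55 = 3 + 20 + 32 = C(2).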

Converting a polynomial to point-value form means evaluating it at n distinct values of x. With Horner's method each evaluation takes \(O(n)\) time, so the whole conversion takes \(O(n^2)\):

\[ A(x_0)=a_0+x_0(a_1+x_0(a_2+...+x_0(a_{n-2}+x_0(a_{n-1}))...))\]
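The nested form above translates directly into a loop that starts from the innermost bracket, i.e. from the highest coefficient. A minimal sketch (the name horner is illustrative), with coefficients stored lowest degree first:

```cpp
#include <vector>

// Horner's rule: A(x0) = a_0 + x0*(a_1 + x0*(a_2 + ... + x0*a_{n-1})).
// One multiplication and one addition per coefficient, so O(n) per point.
template <typename T>
T horner(const std::vector<T>& a, T x0) {
    T result = T(0);
    for (auto it = a.rbegin(); it != a.rend(); ++it)
        result = result * x0 + *it;   // innermost bracket first
    return result;
}
```

For example, horner over the coefficients {1, 2, 3} at x0 = 2 evaluates 1 + 2·2 + 3·4 = 17.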

Going back from point values to coefficients is called interpolation. A polynomial of degree less than n is uniquely determined by its values at n distinct points (see Introduction to Algorithms for the proof, which uses the invertibility of the Vandermonde matrix).

FFT performs both conversions, from coefficients to point values and back, in \(O(n\log n)\) time by evaluating at a special set of points: the complex roots of unity.

n-th Roots of Unity

A complex n-th root of unity is a complex number \(w\) with \(w^n=1\); there are exactly n of them. The principal n-th root of unity is \(w_n=e^{\frac {2\pi } ni}\), and the n roots are its powers \(w_n^0, w_n^1, \ldots, w_n^{n-1}\).

Complex exponentials are defined by Euler's formula:

\[ e^{ui}=\cos(u)+i\sin(u)\]

Hence \(w_n^k=\cos\frac{2\pi k}{n}+i\sin\frac{2\pi k}{n}\), so the n roots are evenly spaced around the unit circle centered at the origin of the complex plane.
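Euler's formula gives a direct way to generate the roots. This sketch (the name roots_of_unity is illustrative) builds \(w_n^k\) with std::polar, which takes a modulus and an angle:

```cpp
#include <cmath>
#include <complex>
#include <vector>

// All n complex n-th roots of unity:
// w_n^k = e^{2*pi*i*k/n} = cos(2*pi*k/n) + i*sin(2*pi*k/n), k = 0..n-1.
std::vector<std::complex<double>> roots_of_unity(int n) {
    const double pi = std::acos(-1.0);
    std::vector<std::complex<double>> w(n);
    for (int k = 0; k < n; ++k)
        w[k] = std::polar(1.0, 2.0 * pi * k / n);   // modulus 1, angle 2*pi*k/n
    return w;
}
```

roots_of_unity(4) yields approximately 1, i, -1, -i; each value has modulus 1 and its n-th power is 1 up to floating-point error.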

Properties of Roots of Unity

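  1. Elimination (cancellation) lemma: \(w^{dk}_{dn}=w^k_n\) for integers \(n>0\), \(d>0\), \(k\ge 0\), proved by substituting the definition: \(w_{dn}^{dk}=e^{\frac{2\pi i\,dk}{dn}}=e^{\frac{2\pi i\,k}{n}}\). Corollary: for even n, \(w^{n/2}_n=w_2=-1\).

  2. Halving lemma: for even n, \((w^{k+n/2}_n)^2=w_n^{2k}w_n^{n}=(w_n^k)^2\), so squaring the n-th roots of unity yields each (n/2)-th root of unity exactly twice. Corollary: \(w^{k+n/2}_n=w_n^kw_n^{n/2}=-w^{k}_n\), using the corollary of lemma 1.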
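Both lemmas and their corollaries can be spot-checked numerically. This is a floating-point sanity check under a chosen tolerance, not a proof; root and check_lemmas are illustrative names, and the factor d = 3 is an arbitrary choice.

```cpp
#include <cmath>
#include <complex>

// w_n^k = e^{2*pi*i*k/n}
std::complex<double> root(int n, int k) {
    const double pi = std::acos(-1.0);
    return std::polar(1.0, 2.0 * pi * k / n);
}

// Check the lemmas for an even n, within tolerance eps.
bool check_lemmas(int n, double eps = 1e-9) {
    for (int k = 0; k < n / 2; ++k) {
        // Elimination lemma with d = 3: w_{3n}^{3k} == w_n^k
        if (std::abs(root(3 * n, 3 * k) - root(n, k)) > eps) return false;
        // Halving corollary: w_n^{k+n/2} == -w_n^k
        if (std::abs(root(n, k + n / 2) + root(n, k)) > eps) return false;
        // Halving lemma: (w_n^{k+n/2})^2 == (w_n^k)^2, and both equal w_{n/2}^k
        std::complex<double> s = root(n, k + n / 2) * root(n, k + n / 2);
        if (std::abs(s - root(n, k) * root(n, k)) > eps) return false;
        if (std::abs(s - root(n / 2, k)) > eps) return false;
    }
    // Elimination corollary: w_n^{n/2} == -1
    return std::abs(root(n, n / 2) + 1.0) <= eps;
}
```

check_lemmas(8) and check_lemmas(16) both return true.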

To make every split even, we pad n up to the next power of 2 by appending zero coefficients. Evaluating A(x) at all n of the n-th roots of unity \(w_n^0,\ldots,w_n^{n-1}\) is exactly the DFT of the coefficient vector:

\[ A(w_n^k)=\sum_{j=0}^{n-1}a_jw^{kj}_n,k=0,1,..,n-1\]

Split A by the parity of the exponents:

\[ A_{even}(x)=a_0+a_2x+a_4x^2+...\\ A_{odd}(x)=a_1+a_3x+a_5x^2+...\\\]

Then,

\[ A(x)=A_{even}(x^2)+xA_{odd}(x^2)\\ A(w_n^k)=A_{even}(w_{n}^{2k})+w_n^kA_{odd}(w_n^{2k})\\\]
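The padding step and the even/odd split can be checked directly: after padding, evaluate A at each root \(w_n^k\) and compare with \(A_{even}(x^2)+xA_{odd}(x^2)\). A minimal sketch with illustrative names (eval_at, pad_pow2, split_even_odd); coefficients are stored lowest degree first.

```cpp
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

using cd = std::complex<double>;

// Evaluate a polynomial with Horner's rule.
cd eval_at(const std::vector<cd>& a, cd x) {
    cd r = 0;
    for (std::size_t i = a.size(); i-- > 0;) r = r * x + a[i];
    return r;
}

// Pad to the next power of two by appending zero coefficients.
std::vector<cd> pad_pow2(std::vector<cd> a) {
    std::size_t n = 1;
    while (n < a.size()) n <<= 1;
    a.resize(n, cd(0));
    return a;
}

// Split A into A_even (a_0, a_2, ...) and A_odd (a_1, a_3, ...).
void split_even_odd(const std::vector<cd>& a, std::vector<cd>& even, std::vector<cd>& odd) {
    even.clear();
    odd.clear();
    for (std::size_t i = 0; i < a.size(); ++i)
        (i % 2 == 0 ? even : odd).push_back(a[i]);
}
```

For a = {1, 2, 3, 4, 5}, pad_pow2 returns 8 coefficients, the two halves have 4 each, and \(A(w_8^k)\) agrees with \(A_{even}(w_8^{2k})+w_8^kA_{odd}(w_8^{2k})\) for every k up to rounding.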

When \(k< n/2\), using the elimination lemma,

\[ A_{even}(w_n^{2k})+w_n^kA_{odd}(w_n^{2k})=A_{even}(w_{n/2}^k)+w_n^kA_{odd}(w_{n/2}^k)\]

This is equivalent to performing DFT on two polynomials \(A_{even}(x)\) and \(A_{odd}(x)\) of length n/2.

For the second half of the outputs, with indices \(k+n/2\) where \(0\le k<n/2\), the halving lemma and its corollary give

\[ A_{even}((w_n^{k+n/2})^2)+w_n^{k+n/2}A_{odd}((w_n^{k+n/2})^2)=A_{even}(w_{n/2}^k)-w_n^kA_{odd}(w_{n/2}^k)\]

So all n outputs need only the two half-size DFTs plus \(O(n)\) work to combine them, which gives the recurrence \(T(n)=2T(n/2)+O(n)\).
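Unrolling the recurrence, writing the \(O(n)\) combine cost as \(cn\) for some constant \(c\) and assuming n is a power of 2, shows where the \(O(n\log n)\) bound comes from:

\[ \begin{aligned} T(n) &= 2T(n/2)+cn \\ &= 4T(n/4)+cn+cn \\ &= \cdots \\ &= 2^{j}T(n/2^{j})+j\,cn \\ &= nT(1)+cn\log_2 n \qquad (j=\log_2 n) \\ &= O(n\log n). \end{aligned}\]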

Recursive FFT (C++ Code)

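#include <cmath>
#include <complex>
#include <iostream>
using namespace std;

typedef complex<double> CD;
const int N = 1 << 18;              // maximum transform length (assumed; must be >= n)
const double pi = acos(-1.0);
CD tmp[N], epsilon[N], inv_epsilon[N];

// epsilon[i] = w_n^i, inv_epsilon[i] = w_n^{-i} (the conjugate, since |w_n| = 1)
void init_epsilon(int n){
    for(int i = 0; i < n; ++i){
        epsilon[i] = CD(cos(2.0 * pi * i / n), sin(2.0 * pi * i / n));
        inv_epsilon[i] = conj(epsilon[i]);
    }
}

// In-place DFT of the n elements A[offset], A[offset+step], ..., A[offset+(n-1)*step].
// w holds the roots for the top-level length, so w[k*step] is w_n^k at this level.
void recursive_fft(int n, CD* A, int offset, int step, CD* w){
    if(n==1)return;
    int m=n>>1;
    recursive_fft(m,A,offset,step<<1,w);        // even-indexed coefficients
    recursive_fft(m,A,offset+step,step<<1,w);   // odd-indexed coefficients
    for(int k=0;k<m;++k){
        int pos=2*step*k;   // k-th even result at pos+offset, k-th odd result right after it
        tmp[k]  =A[pos+offset]+w[k*step]*A[pos+offset+step];
        tmp[k+m]=A[pos+offset]-w[k*step]*A[pos+offset+step];
    }
    for(int i=0;i<n;++i)
        A[i*step+offset]=tmp[i];
}

int main(){
    // Example: (1 + 2x)(3 + 4x) = 3 + 10x + 8x^2, padded to n = 4.
    int n = 4;
    CD a[4] = {1, 2, 0, 0}, b[4] = {3, 4, 0, 0};
    init_epsilon(n);
    recursive_fft(n, a, 0, 1, epsilon);
    recursive_fft(n, b, 0, 1, epsilon);
    for(int i = 0; i < n; ++i) a[i] *= b[i];
    recursive_fft(n, a, 0, 1, inv_epsilon);   // inverse DFT (see the IDFT section), then divide by n
    for(int i = 0; i < n; ++i) cout << llround(a[i].real() / n) << ' ';   // prints: 3 10 8 0
    cout << endl;
    return 0;
}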

Iterative Implementation

Recursion, however, adds function-call overhead and needs an extra buffer (tmp). Can the same computation be done iteratively?

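Look at the order in which the recursion rearranges the indices (n = 8). After the first split (even indices first, then odd):

0 (000), 2 (010), 4 (100), 6 (110) | 1 (001), 3 (011), 5 (101), 7 (111)

After the second split:

0 (000), 4 (100) | 2 (010), 6 (110) | 1 (001), 5 (101) | 3 (011), 7 (111)

Reversing the bits of each index in this final order gives

000, 001, 010, 011, 100, 101, 110, 111

that is, 0, 1, 2, ..., 7. In other words, position i of the leaf order holds the coefficient whose index is the bit reversal of i.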

To enumerate the bit reversals of 0, 1, ..., n-1 (n a power of two), keep the reversed number itself and increment it in reversed form: add 1 at its highest bit and let the carry propagate toward the lower bits. The following function does this; bit_length is n, so the highest bit is n >> 1:

int reverse_add_1(int x, int bit_length){
    for(int l=bit_length>>1;(x^=l)<l;l>>=1);
    return x;
}
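To see that repeated reverse_add_1 calls produce exactly the leaf order found above, this sketch compares them with a straightforward bit-by-bit reversal. reverse_bits and reversed_sequence are illustrative helpers; reverse_add_1 is the same loop as the original function.

```cpp
#include <vector>

// Increment x in bit-reversed order (same loop as reverse_add_1 above).
int reverse_add_1(int x, int bit_length) {
    for (int l = bit_length >> 1; (x ^= l) < l; l >>= 1);
    return x;
}

// Reference: reverse the low log2(n) bits of i one bit at a time.
int reverse_bits(int i, int n) {
    int r = 0;
    for (int l = 1; l < n; l <<= 1) {
        r = (r << 1) | (i & 1);
        i >>= 1;
    }
    return r;
}

// Bit-reversed order of 0..n-1, generated by repeated reverse_add_1 calls.
std::vector<int> reversed_sequence(int n) {
    std::vector<int> seq(n);
    int j = 0;
    for (int i = 0; i < n; ++i) {
        seq[i] = j;
        j = reverse_add_1(j, n);
    }
    return seq;
}
```

reversed_sequence(8) yields 0, 4, 2, 6, 1, 5, 3, 7, matching the order after the second split.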

The loop starts at i=1 with j=0, the bit reversal of 0; each reverse_add_1 call advances j to the bit reversal of i. A[i] and A[j] are swapped only when i > j, so each pair is swapped exactly once and indices equal to their own reversal stay in place.

void bit_reverse(CD* A,int n){
    for(int i=1,j=0;i<n;++i){
        j = reverse_add_1(j, n);
        if(i>j)swap(A[i],A[j]);
    }
}

With the array in bit-reversed order, the FFT can be computed iteratively, bottom-up: first merge subproblems of length 2, then 4, and so on up to n.

void fft(CD* A, int n, CD* w){
    bit_reverse(A,n);
    for(int i=2;i<=n;i<<=1)               // bottom-up: i is the subproblem length in this layer
      for(int j=0,m=i>>1;j<n;j+=i)        // j is the start of each subproblem in this layer
        for(int k=0;k<m;++k){             // k = 0..m-1: combine elements k and k+m (butterfly)
          CD b=w[n/i*k]*A[j+m+k];         // w[n/i*k] = w_n^{(n/i)k} = w_i^k
          A[j+m+k]=A[j+k]-b;
          A[j+k]+=b;
        }
}
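A self-contained version of the iterative transform can be checked against the direct \(O(n^2)\) DFT with the same sign convention. This is a sketch: fft_iter inlines the bit-reversal loop from bit_reverse, and the names fft_iter, make_roots and dft_reference are illustrative.

```cpp
#include <cmath>
#include <complex>
#include <utility>
#include <vector>

using CD = std::complex<double>;

// Iterative radix-2 FFT; A.size() must be a power of two, w[j] = w_n^j.
void fft_iter(std::vector<CD>& A, const std::vector<CD>& w) {
    const int n = static_cast<int>(A.size());
    for (int i = 1, j = 0; i < n; ++i) {            // bit-reversal permutation
        for (int l = n >> 1; (j ^= l) < l; l >>= 1);
        if (i > j) std::swap(A[i], A[j]);
    }
    for (int i = 2; i <= n; i <<= 1)                 // subproblem length
        for (int j = 0, m = i >> 1; j < n; j += i)   // start of each subproblem
            for (int k = 0; k < m; ++k) {            // butterfly on k and k + m
                CD b = w[n / i * k] * A[j + m + k];
                A[j + m + k] = A[j + k] - b;
                A[j + k] += b;
            }
}

// w_n^j = e^{2*pi*i*j/n} for j = 0..n-1
std::vector<CD> make_roots(int n) {
    const double pi = std::acos(-1.0);
    std::vector<CD> w(n);
    for (int j = 0; j < n; ++j) w[j] = std::polar(1.0, 2.0 * pi * j / n);
    return w;
}

// Direct O(n^2) DFT with the same convention: y_k = sum_j a_j * w_n^{kj}.
std::vector<CD> dft_reference(const std::vector<CD>& a) {
    const int n = static_cast<int>(a.size());
    std::vector<CD> w = make_roots(n), y(n);
    for (int k = 0; k < n; ++k)
        for (int j = 0; j < n; ++j) y[k] += a[j] * w[(long long)k * j % n];
    return y;
}
```

Running fft_iter on a copy of any length-8 input with make_roots(8) should match dft_reference on the same input up to rounding error.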

Inverse Discrete Fourier Transform (IDFT)

The IDFT converts point values back into coefficients. This amounts to solving the linear system \(\vec{y} = V_n \vec{a}\), where \(\vec{a}\) is the coefficient vector and \(V_n\) is the Vandermonde matrix of the roots of unity:

\[ V_n=\begin{bmatrix} 1&1&1&1&\cdots &1\\ 1&w_n&w_n^2&w_n^3&\cdots &w_n^{n-1}\\ 1&w_n^2&w_n^4&w_n^6&\cdots &w_n^{2(n-1)}\\ \vdots&\vdots&\vdots&\vdots&\ddots&\vdots\\ 1&w_n^{n-1}&w_n^{2(n-1)}&w_n^{3(n-1)}&\cdots &w_n^{(n-1)(n-1)} \end{bmatrix}\]

Therefore, \(V_n^{-1}\vec{y} = \vec{a}\).

Let \([V_n^{-1}]_{kj} = w^{-kj}_n/n\), and we need to prove \(V_n^{-1}V_n = I_n\):

\[ [V_n^{-1}V_n]_{jj'}=\sum_{k=0}^{n-1}(w_n^{-kj}/n)(w_n^{kj'})=\sum_{k=0}^{n-1}w_n^{k(j'-j)}/n = \left\{ \begin{aligned} 0,j'\neq j \\ 1,j'=j \end{aligned} \right.\]

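For \(j'\neq j\), the sum is a geometric series with ratio \(w_n^{j'-j}\neq 1\) (because \(0<|j'-j|<n\)), so it equals \(\frac1n\cdot\frac{(w_n^{n})^{j'-j}-1}{w_n^{j'-j}-1}=0\). Thus the inverse DFT \(a_j=\frac1n\sum_{k=0}^{n-1}y_kw_n^{-kj}\) is computed by running the same FFT with \(w_n\) replaced by \(w_n^{-1}\) and dividing each element of the result by \(n\). Since \(|w_n|=1\), \(w_n^{-1}\) is the complex conjugate of \(w_n\), which is why the code builds the inverse roots with conj.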
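The inverse formula can be verified numerically for a small n by building both matrices explicitly and multiplying them. This is an \(O(n^3)\) illustration only; vandermonde, vandermonde_inverse and multiply are illustrative names.

```cpp
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

using CD = std::complex<double>;
using Matrix = std::vector<std::vector<CD>>;

// [V_n]_{kj} = w_n^{kj}
Matrix vandermonde(int n) {
    const double pi = std::acos(-1.0);
    Matrix V(n, std::vector<CD>(n));
    for (int k = 0; k < n; ++k)
        for (int j = 0; j < n; ++j) V[k][j] = std::polar(1.0, 2.0 * pi * k * j / n);
    return V;
}

// Claimed inverse: [V_n^{-1}]_{jk} = w_n^{-jk} / n
Matrix vandermonde_inverse(int n) {
    const double pi = std::acos(-1.0);
    Matrix W(n, std::vector<CD>(n));
    for (int j = 0; j < n; ++j)
        for (int k = 0; k < n; ++k) W[j][k] = std::polar(1.0, -2.0 * pi * j * k / n) / double(n);
    return W;
}

// Plain O(n^3) product of two square matrices.
Matrix multiply(const Matrix& X, const Matrix& Y) {
    const std::size_t n = X.size();
    Matrix Z(n, std::vector<CD>(n));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t t = 0; t < n; ++t)
            for (std::size_t j = 0; j < n; ++j) Z[i][j] += X[i][t] * Y[t][j];
    return Z;
}
```

multiply(vandermonde_inverse(8), vandermonde(8)) is the 8 × 8 identity matrix up to rounding error.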

High-Precision Multiplication with FFT (C++ Code)

51Nod 1028: Big Number Multiplication V2

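Because the roots of unity come from cos and sin and all arithmetic is done in floating point, the inverse transform returns values that are only approximately integers. Each coefficient is therefore rounded to the nearest integer by adding 0.5 before truncating to int.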

#include <bits/stdc++.h>
using namespace std;
#define rep(i,l,r) for(int i=l;i<r;++i)
#define per(i,l,r) for(int i=r-1;i>=l;--i)
#define SZ(x) ((int)(x).size())

typedef double dd;
typedef complex<dd> CD;
const dd PI=acos(-1.0);
const int L=18,N=1<<L;

CD eps[N],inv_eps[N],f[N],g[N];
void init_eps(int p){
    rep(i,0,p)eps[i]=CD(cos(PI*i*2/p),sin(PI*i*2/p)),inv_eps[i]=conj(eps[i]);
}
void fft(CD p[], int n, CD w[]){
    for(int i=0,j=0;i<n;++i){
        if(i>j)swap(p[i],p[j]);
        for(int l=n>>1;(j^=l)<l;l>>=1);
    }
    for(int i=2;i<=n;i<<=1)
        for(int j=0,m=i>>1;j<n;j+=i)
            rep(k,0,m){
                CD b=w[n/i*k]*p[j+m+k];
                p[j+m+k]=p[j+k]-b;
                p[j+k]+=b;
            }
}
int ans[N];
int main(){
    string a,b;
    cin>>a>>b;
    int n=max(SZ(a),SZ(b)),p=1;
    while(p<n)p<<=1;p<<=1;
    rep(i,0,p)f[i]=g[i]=0;
    n=0;per(i,0,SZ(a))f[n++]=a[i]-'0';
    n=0;per(i,0,SZ(b))g[n++]=b[i]-'0';

    init_eps(p);
    fft(f,p,eps);fft(g,p,eps);
    rep(i,0,p)f[i]*=g[i];
    fft(f,p,inv_eps);

    int t=0;
    rep(i,0,p){
        ans[i]=t+(int)(f[i].real()/p+0.5);  // divide by p first, then round to the nearest integer
        if(ans[i]>9){t=ans[i]/10;ans[i]%=10;}
        else t=0;
    }
    bool flag=0;
    per(i,0,p)if(ans[i]||flag){
        printf("%d",ans[i]);flag=1;
    }
    if(flag==0)puts("0");
    return 0;
}

Number-Theoretic Transform (NTT)

To avoid FFT's floating-point errors, the same transform can be carried out modulo a prime using only integer arithmetic. This is the Number-Theoretic Transform (NTT).

Primitive Root

Let \(m\) be a positive integer and \(a\) an integer with \(\gcd(a,m)=1\). The smallest positive integer \(r\) satisfying \(a^r \equiv 1 \pmod m\) is called the order (also called the exponent) of \(a\) modulo \(m\).

If the exponent of \(a\) modulo \(m\) is equal to \(\varphi(m)\) (Euler's totient function), then \(a\) is called a primitive root modulo \(m\).
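Both definitions are easy to compute. This sketch finds the order by brute force (fine only for small m) and, for a prime p, tests primitivity with the standard criterion: g is a primitive root modulo p if and only if \(g^{(p-1)/q}\not\equiv 1 \pmod p\) for every prime factor q of p − 1. The names pow_mod, order_mod and is_primitive_root are illustrative, and moduli are assumed to be below \(2^{32}\) so 64-bit products do not overflow.

```cpp
#include <cstdint>

using u64 = std::uint64_t;

// a^b mod m by binary exponentiation (m < 2^32 so products fit in 64 bits).
u64 pow_mod(u64 a, u64 b, u64 m) {
    u64 r = 1 % m;
    for (a %= m; b; b >>= 1, a = a * a % m)
        if (b & 1) r = r * a % m;
    return r;
}

// Order of a modulo m by brute force: smallest r > 0 with a^r = 1 (mod m).
// Returns 0 if no such r exists (i.e. gcd(a, m) != 1).
u64 order_mod(u64 a, u64 m) {
    a %= m;
    u64 x = a;
    for (u64 r = 1; r <= m; ++r) {
        if (x == 1) return r;
        x = x * a % m;
    }
    return 0;
}

// For a prime p: g is a primitive root iff g^((p-1)/q) != 1 (mod p) for every
// prime factor q of p - 1 (factors found by trial division).
bool is_primitive_root(u64 g, u64 p) {
    if (g % p == 0) return false;
    u64 rest = p - 1;
    for (u64 q = 2; q * q <= rest; ++q) {
        if (rest % q != 0) continue;
        if (pow_mod(g, (p - 1) / q, p) == 1) return false;
        while (rest % q == 0) rest /= q;
    }
    if (rest > 1 && pow_mod(g, (p - 1) / rest, p) == 1) return false;
    return true;
}
```

order_mod(3, 7) is 6 = \(\varphi(7)\), so 3 is a primitive root modulo 7, while order_mod(2, 7) is 3. The same test confirms that 3 is a primitive root modulo \(479\cdot 2^{21}+1=1004535809\), the prime used in the NTT code.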

Properties of Primitive Roots

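For a prime \(p\) and a primitive root \(g\) modulo \(p\), the powers \(g^0, g^1, \ldots, g^{p-2}\) are pairwise distinct modulo \(p\) and form a reduced residue system modulo \(p\); that is, they are \(1, 2, \ldots, p-1\) in some order.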

Let \(p = c \cdot 2^k + 1\) be prime. Then every power of two \(n = 2^m\) with \(m \le k\) divides \(p-1\), and we define

\[ g_n = g^{\frac{p-1}{n}}\]
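A small sketch makes both facts concrete: for \(p = 17 = 1\cdot 2^4+1\) with primitive root 3, the powers \(3^0,\ldots,3^{15}\) are 16 distinct residues, and for the document's prime \(P = 479\cdot 2^{21}+1\) the value \(g_8 = 3^{(P-1)/8}\) satisfies \(g_8^8\equiv 1\) and \(g_8^4\equiv -1\). The helper names are illustrative.

```cpp
#include <cstddef>
#include <cstdint>
#include <set>

using u64 = std::uint64_t;

// a^b mod m (m < 2^32 so products fit in 64 bits).
u64 pow_mod(u64 a, u64 b, u64 m) {
    u64 r = 1 % m;
    for (a %= m; b; b >>= 1, a = a * a % m)
        if (b & 1) r = r * a % m;
    return r;
}

// g_n = g^((p-1)/n) mod p; requires n | p - 1.
u64 nth_root_mod(u64 g, u64 p, u64 n) {
    return pow_mod(g, (p - 1) / n, p);
}

// Number of distinct values among g^0, g^1, ..., g^{count-1} (mod p).
std::size_t distinct_powers(u64 g, u64 p, u64 count) {
    std::set<u64> seen;
    u64 x = 1;
    for (u64 i = 0; i < count; ++i) {
        seen.insert(x);
        x = x * g % p;
    }
    return seen.size();
}
```

distinct_powers(3, 17, 16) returns 16, and nth_root_mod(3, 1004535809, 8) gives an element of order exactly 8.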

Recall the properties of the roots of unity \(w_n\) that FFT relies on:

  1. The values \(w_n^k\ (0 \leq k < n)\) are pairwise distinct, so they are valid evaluation points for the point-value representation.

  2. \(w_n^{2k} = w_{n/2}^k\), used for divide-and-conquer.

  3. \(w_n^{k + \frac{n}{2}} = -w_n^k\), used for divide-and-conquer.

  4. \(\sum_{j=0}^{n-1}(w_n^k)^j = 0\) for \(0 < k < n\) (and it equals \(n\) for \(k = 0\)), used for the inverse transform.

Does \(g_n\), used in place of \(w_n\), have these properties?

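  1. Let \(p = c \cdot 2^k + 1\) and \(n = 2^m\) with \(m \le k\), so \(\frac{p-1}{n} = c \cdot 2^{k-m}\). For \(0 \le j < n\), \(g_n^j = g^{jc \cdot 2^{k-m}}\), so the values are \(g^0, g^{c \cdot 2^{k-m}}, g^{2c \cdot 2^{k-m}}, \ldots, g^{(n-1)c \cdot 2^{k-m}}\). All of these exponents lie in \([0, p-2]\), since \((n-1)c \cdot 2^{k-m} = p-1-c \cdot 2^{k-m}\), and they are distinct, so the values are pairwise distinct because \(g^0, \ldots, g^{p-2}\) are.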

A related fact: if \(a_1, a_2, \ldots, a_p\) form a complete residue system modulo the prime \(p\) and \(p \nmid a\), then \(aa_1+b, aa_2+b, \ldots, aa_p+b\) also form a complete residue system modulo \(p\). Proof: if \(aa_i+b \equiv aa_j+b \pmod p\) for some \(i \neq j\), then \(p \mid a(a_i-a_j)\); since \(p\) is prime and \(p \nmid a\), this gives \(a_i \equiv a_j \pmod p\), a contradiction.

  2. \(g_n^{2k} = g^{\frac{2k(p-1)}{n}} = g^{\frac{k(p-1)}{n/2}} = g_{n/2}^k\).

  3. \((g^{\frac{p-1}{2}})^2 = g^{p-1} \equiv 1 \pmod p\), so \(g^{\frac{p-1}{2}} \equiv \pm 1 \pmod p\) because \(p\) is prime. Since \(g^0 = 1\) and \(g^{\frac{p-1}{2}}\) are distinct powers among \(g^0, \ldots, g^{p-2}\), we get \(g^{\frac{p-1}{2}} \equiv -1 \pmod p\). Thus \(g_n^{k+\frac{n}{2}} = g_n^k \cdot g_n^{\frac{n}{2}} = g_n^k \cdot g^{\frac{p-1}{2}} \equiv -g_n^k \pmod p\).

  4. For \(0 < k < n\), \(g_n^k \not\equiv 1\) by property 1, so the geometric series gives \(\sum_{j=0}^{n-1}(g_n^k)^j = \frac{1 - (g_n^k)^n}{1 - g_n^k}\). By Fermat's little theorem \(g_n^n = g^{p-1} \equiv 1 \pmod p\), so \(1 - (g_n^k)^n = 1 - (g_n^n)^k \equiv 0\) and the sum is 0. For \(k = 0\) the sum equals \(n\).
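Properties 2 to 4 can also be confirmed by brute force for concrete parameters. This sketch (check_ntt_root_properties is an illustrative name) assumes n is even and divides p − 1, and that p is below \(2^{32}\).

```cpp
#include <cstdint>

using u64 = std::uint64_t;

// a^b mod m (m < 2^32 so products fit in 64 bits).
u64 pow_mod(u64 a, u64 b, u64 m) {
    u64 r = 1 % m;
    for (a %= m; b; b >>= 1, a = a * a % m)
        if (b & 1) r = r * a % m;
    return r;
}

// Check properties 2-4 for g_n = g^((p-1)/n) mod p.
bool check_ntt_root_properties(u64 g, u64 p, u64 n) {
    const u64 gn = pow_mod(g, (p - 1) / n, p);
    const u64 gh = pow_mod(g, (p - 1) / (n / 2), p);   // g_{n/2}
    for (u64 k = 0; k < n; ++k) {
        const u64 gnk = pow_mod(gn, k, p);
        if (k < n / 2) {
            if (pow_mod(gn, 2 * k, p) != pow_mod(gh, k, p)) return false;    // property 2
            if (pow_mod(gn, k + n / 2, p) != (p - gnk) % p) return false;   // property 3
        }
        u64 sum = 0, term = 1;                                             // property 4
        for (u64 j = 0; j < n; ++j) {
            sum = (sum + term) % p;
            term = term * gnk % p;
        }
        if (sum != (k == 0 ? n % p : 0)) return false;
    }
    return true;
}
```

With the primitive root 3 the check passes for p = 17, n = 16 and for \(P = 1004535809\), n = 16; with g = 2 modulo 17 (order 8, not a primitive root) property 3 already fails at k = 0.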

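Hence, replacing \(w_n\) by \(g_n\) in FFT (and \(w_n^{-1}\) by \(g_n^{-1}\) in the inverse, with the division by \(n\) done by multiplying by its inverse modulo \(p\)) gives the NTT. For an arbitrary modulus, which need not have the form \(c \cdot 2^k + 1\), one can run NTT modulo several such primes and combine the results with the Chinese Remainder Theorem (CRT).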
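The CRT step combines, for each coefficient, its residues modulo the different NTT primes into one value. This minimal sketch handles two distinct primes below \(2^{31}\) by writing \(x = r_1 + m_1t\) and solving for t modulo \(m_2\); it assumes the true value is below \(m_1m_2\). For an arbitrary modulus around \(10^9\), three primes are usually needed, combined the same way one modulus at a time. crt2 is an illustrative name.

```cpp
#include <cstdint>

using u64 = std::uint64_t;

// a^b mod m (m < 2^32 so products fit in 64 bits).
u64 pow_mod(u64 a, u64 b, u64 m) {
    u64 r = 1 % m;
    for (a %= m; b; b >>= 1, a = a * a % m)
        if (b & 1) r = r * a % m;
    return r;
}

// Combine x = r1 (mod m1) and x = r2 (mod m2) for distinct primes m1, m2 < 2^31.
// Returns the unique x in [0, m1 * m2).
u64 crt2(u64 r1, u64 m1, u64 r2, u64 m2) {
    r1 %= m1;
    const u64 inv = pow_mod(m1 % m2, m2 - 2, m2);     // m1^{-1} mod m2 (Fermat, m2 prime)
    const u64 diff = (r2 % m2 + m2 - r1 % m2) % m2;   // (r2 - r1) mod m2
    const u64 t = diff * inv % m2;                    // x = r1 + m1 * t satisfies both
    return r1 + m1 * t;                               // < m1 * m2 < 2^62
}
```

For example, crt2(2, 3, 3, 5) returns 8, the unique value in [0, 15) with 8 mod 3 = 2 and 8 mod 5 = 3.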

NTT (C++ Code)

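The code below solves the same high-precision multiplication problem with NTT modulo \(P = 479 \cdot 2^{21} + 1 = 1004535809\), using the primitive root \(G = 3\). Each coefficient of the digit convolution is at most 81 times the shorter input length, far below \(P\) for the lengths the arrays allow, so the results modulo \(P\) are the exact coefficients.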

#include <bits/stdc++.h>
using namespace std;
#define rep(i,l,r) for(int i=l;i<r;++i)
#define per(i,l,r) for(int i=r-1;i>=l;--i)
#define SZ(x) ((int)(x).size())

typedef long long LL;
const int L=18,N=1<<L;

const LL C = 479;
const LL P = (C << 21) + 1;
const LL G = 3;

LL qpow(LL a, LL b, LL m){
    LL ans = 1;
    for(a%=m;b;b>>=1,a=a*a%m)if(b&1)ans=ans*a%m;
    return ans;
}

LL eps[N],inv_eps[N],f[N],g[N];
void init_eps(int n){
    LL t=(P-1)/n, invG=qpow(G,P-2,P);
    rep(i,0,n) eps[i]=qpow(G,t*i,P),inv_eps[i]=qpow(invG,t*i,P);
}
void fft(LL p[], int n, LL w[]){
    for(int i=0,j=0;i<n;++i){
        if(i>j)swap(p[i],p[j]);
        for(int l=n>>1;(j^=l)<l;l>>=1);
    }
    for(int i=2;i<=n;i<<=1)
        for(int j=0,m=i>>1;j<n;j+=i)
            rep(k,0,m){
                LL b=w[n/i*k]*p[j+m+k]%P;
                p[j+m+k]=(p[j+k]-b+P)%P;
                p[j+k]=(p[j+k]+b)%P;
            }
}
LL ans[N];
int main(){
    string a,b;
    cin>>a>>b;
    int n=max(SZ(a),SZ(b)),p=1;
    while(p<n)p<<=1;p<<=1;
    rep(i,0,p)f[i]=g[i]=0;
    n=0;per(i,0,SZ(a))f[n++]=a[i]-'0';
    n=0;per(i,0,SZ(b))g[n++]=b[i]-'0';

    init_eps(p);
    fft(f,p,eps);fft(g,p,eps);
    rep(i,0,p)f[i]=f[i]*g[i]%P;
    fft(f,p,inv_eps);

    int t=0;
    LL invp=qpow(p,P-2,P);
    rep(i,0,p){
        ans[i]=(t+f[i]*invp%P)%P;
        if(ans[i]>9){t=ans[i]/10;ans[i]%=10;}
        else t=0;
    }
    bool flag=0;
    per(i,0,p)if(ans[i]||flag){
        printf("%lld",ans[i]);flag=1;
    }
    if(flag==0)puts("0");
    return 0;
}

