卷积

给定向量: a=(a0,a1,...,an1) <script type="math/tex" id="MathJax-Element-1">a=(a_0,a_1,...,a_{n-1})</script>, b=(b0,b1,...,bn1) <script type="math/tex" id="MathJax-Element-2">b=(b_0,b_1,...,b_{n-1})</script>

向量和: a+b=(a0+b0,a1+b1,...,an1+bn1) <script type="math/tex" id="MathJax-Element-3">a+b=(a_0+b_0,a_1+b_1,...,a_{n-1}+b_{n-1})</script>
数量积(内积、点积): ab=a0b0+a1b1+...+an1bn1 <script type="math/tex" id="MathJax-Element-4">a·b=a_0b_0+a_1b_1+...+a_{n-1}b_{n-1}</script>
卷积 ab=(c0,c1,...,c2n2) <script type="math/tex" id="MathJax-Element-5">a \otimes b=(c_0,c_1,...,c_{2n-2})</script>,其中 ck=i+j=k(aibj) <script type="math/tex" id="MathJax-Element-6">c_k=\sum_{i+j=k}(a_ib_j)</script>

例如: cn1=a0bn1+a1bn2+...+an2b1+an1b0 <script type="math/tex" id="MathJax-Element-7">c_{n-1}=a_0b_{n-1}+a_1b_{n-2}+...+a_{n-2}b_1+a_{n-1}b_0</script>

卷积的最典型的应用就是多项式乘法(多项式乘法就是求卷积)。以下就用多项式乘法来描述、举例卷积与DFT。

关于多项式

对于多项式 A(x) <script type="math/tex" id="MathJax-Element-8">A(x)</script>,系数为 ai <script type="math/tex" id="MathJax-Element-9">a_i</script>,设最高非零系数为 ak <script type="math/tex" id="MathJax-Element-10">a_k</script>,则其次数就是 k <script type="math/tex" id="MathJax-Element-11">k</script>,记作degree(A)=k<script type="math/tex" id="MathJax-Element-12">degree(A)=k</script>。任何大于 k <script type="math/tex" id="MathJax-Element-13">k</script>的整数都是A(x)<script type="math/tex" id="MathJax-Element-14">A(x)</script>的次数界

多项式的系数表达方式 A(x)=a0+a1x+a2x2+...+an1xn1=n1i=0ajxj <script type="math/tex" id="MathJax-Element-15">A(x)=a_0+a_1x+a_2x^2+...+a_{n-1}x^{n-1}=\sum^{n-1}_{i=0} a_jx^j</script>(次数界为 n <script type="math/tex" id="MathJax-Element-16">n</script>)。
则多项式的系数向量即为a=(a0,a1,...,an1)<script type="math/tex" id="MathJax-Element-17">a=(a_0,a_1,...,a_{n-1})</script>。
多项式的点值表达方式 {(x0,y0),(x1,y1),...,(xn1,yn1)} <script type="math/tex" id="MathJax-Element-18">\{(x_0,y_0),(x_1,y_1),...,(x_{n-1},y_{n-1})\}</script>,其中 xk <script type="math/tex" id="MathJax-Element-19">x_k</script>各不相同, yk=A(xk) <script type="math/tex" id="MathJax-Element-20">y_k=A(x_k)</script>。

离散傅里叶变换(DFT)

离散傅里叶变换(Discrete Fourier Transform,DFT)。在信号处理很重要的一个东西,这里物理意义以及其他应用暂不予理睬。在多项式中,DFT就是系数表式转换成点值表示的过程。

快速傅里叶变换(FFT)

快速傅里叶变换(Fast Fourier Transformation,FFT):快速计算DFT的算法,能够在 O(nlogn) <script type="math/tex" id="MathJax-Element-41">O(n\log n)</script>的时间里完成DFT。FFT只是快速的求DFT的方法罢了,不是一个新的概念。 在ACM-ICPC竞赛中, FFT算法常被用来为多项式乘法加速。FFT与其逆变换IFFT类似,稍微加几行代码。

求FFT要用到复数。一个简单的模板:

struct Complex // 复数
{
    double r, i;
    Complex(double _r = 0, double _i = 0) :r(_r), i(_i) {}
    Complex operator +(const Complex &b) {
        return Complex(r + b.r, i + b.i);
    }
    Complex operator -(const Complex &b) {
        return Complex(r - b.r, i - b.i);
    }
    Complex operator *(const Complex &b) {
        return Complex(r*b.r - i*b.i, r*b.i + i*b.r);
    }
};

递归实现FFT模板:来源

Complex* RecursiveFFT(Complex a[], int n)//n表示向量a的维数
{
    if(n == 1)
        return a;
    Complex wn = Complex(cos(2*PI/n), sin(2*PI/n));
    Complex w = Complex(1, 0);
    Complex* a0 = new Complex[n >> 1];
    Complex* a1 = new Complex[n >> 1];
    for(int i = 0; i < n; i++)
        if(i & 1) a1[(i - 1) >> 1] = a[i];
        else a0[i >> 1] = a[i];
    Complex *y0, *y1;
    y0 = RecursiveFFT(a0, n >> 1);
    y1 = RecursiveFFT(a1, n >> 1);
    Complex* y = new Complex[n];
    for(int k = 0; k < (n >> 1); k++)
    {
        y[k] = y0[k] + w*y1[k];
        y[k + (n >> 1)] = y0[k] - w*y1[k];
        w = w*wn;
    }
    delete a0;
    delete a1;
    delete y0;
    delete y1;
    return y;
}

非递归实现。模板:(来源忘了)

void change(Complex y[], int len) // 二进制平摊反转置换 O(logn)  
{
    int i, j, k;
    for (i = 1, j = len / 2;i < len - 1;i++)
    {
        if (i < j)swap(y[i], y[j]);
        k = len / 2;
        while (j >= k)
        {
            j -= k;
            k /= 2;
        }
        if (j < k)j += k;
    }
}
void fft(Complex y[], int len, int on) //FFT:on=1; IFFT:on=-1
{
    change(y, len);
    for (int h = 2;h <= len;h <<= 1)
    {
        Complex wn(cos(-on * 2 * PI / h), sin(-on * 2 * PI / h));
        for (int j = 0;j < len;j += h)
        {
            Complex w(1, 0);
            for (int k = j;k < j + h / 2;k++)
            {
                Complex u = y[k];
                Complex t = w*y[k + h / 2];
                y[k] = u + t;
                y[k + h / 2] = u - t;
                w = w*wn;
            }
        }
    }
    if (on == -1)
        for (int i = 0;i < len;i++)
            y[i].r /= len;
}

利用FFT求卷积

普通的计算多项式乘法的计算,时间复杂度 O(n2) <script type="math/tex" id="MathJax-Element-22">O(n^2)</script>。而FFT先将多项式点值表示( O(nlogn) <script type="math/tex" id="MathJax-Element-23">O(n\log n)</script>),在 O(n) <script type="math/tex" id="MathJax-Element-24">O(n)</script>下完成对点值的乘法,再以 O(nlogn) <script type="math/tex" id="MathJax-Element-25">O(n\log n)</script>完成IFFT,重新得到系数表示。

步骤一(补0)

在两个多项式前面补0,得到两个2n次多项式,设系数向量分别为 v1 <script type="math/tex" id="MathJax-Element-26">v_1</script>和 v2 <script type="math/tex" id="MathJax-Element-27">v_2</script>。

步骤二(求值)

使用FFT计算 f1=DFT(v1) <script type="math/tex" id="MathJax-Element-28">f_1=DFT(v_1)</script>和 f2=DFT(v2) <script type="math/tex" id="MathJax-Element-29">f_2=DFT(v_2)</script>。则 f1 <script type="math/tex" id="MathJax-Element-30">f_1</script>与 f2 <script type="math/tex" id="MathJax-Element-31">f_2</script>为两个多项式在 2n <script type="math/tex" id="MathJax-Element-32">2n</script>次单位根处的取值(即点值表示)。

步骤三(乘法)

f1 <script type="math/tex" id="MathJax-Element-33">f_1</script>与 f2 <script type="math/tex" id="MathJax-Element-34">f_2</script>每一维对应相乘,得到 f <script type="math/tex" id="MathJax-Element-35">f</script>,代表对应输入多项式乘积的点值表示。

步骤四(插值)

使用IFFT计算v=IDFT(f)<script type="math/tex" id="MathJax-Element-36">v=IDFT(f)</script>,其中 v <script type="math/tex" id="MathJax-Element-37">v</script>就是乘积的系数向量。

综上

ab=IDFT2n(DFT2n(a)DFT2n(b))<script type="math/tex" id="MathJax-Element-42">a \otimes b=IDFT_{2n}(DFT_{2n}(a)·DFT_{2n}(b))</script>,即: ab=DFT12n(DFT2n(a)DFT2n(b)) <script type="math/tex" id="MathJax-Element-43">a \otimes b=DFT^{-1}_{2n}(DFT_{2n}(a)·DFT_{2n}(b))</script>

A(x1)B(x2) <script type="math/tex" id="MathJax-Element-44">A(x1)\otimes B(x2)</script>:

fft(x1, len, 1);
fft(x2, len, 1);
for (int i = 0;i < len;i++) {
    x[i] = x1[i] * x2[i];
}
fft(x, len, -1);

例题

1.2016 acm香港网络赛 A题 A+B Problem

网上的代码(当时没保留出处。。。)

#include <algorithm>
#include <cstring>
#include <string.h>
#include <iostream>
#include <list>
#include <map>
#include <set>
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include <cstdio>
#include <cmath>

#define LL long long
#define N 200005
#define INF 0x3ffffff

using namespace std;

const double PI = acos(-1.0);


struct Complex // 复数
{
    double r, i;
    Complex(double _r = 0, double _i = 0) :r(_r), i(_i) {}
    Complex operator +(const Complex &b)
    {
        return Complex(r + b.r, i + b.i);
    }
    Complex operator -(const Complex &b)
    {
        return Complex(r - b.r, i - b.i);
    }
    Complex operator *(const Complex &b)
    {
        return Complex(r*b.r - i*b.i, r*b.i + i*b.r);
    }
};

void change(Complex y[], int len) // 二进制平摊反转置换 O(logn)  
{
    int i, j, k;
    for (i = 1, j = len / 2;i < len - 1;i++)
    {
        if (i < j)swap(y[i], y[j]);
        k = len / 2;
        while (j >= k)
        {
            j -= k;
            k /= 2;
        }
        if (j < k)j += k;
    }
}
void fft(Complex y[], int len, int on) //DFT和FFT
{
    change(y, len);
    for (int h = 2;h <= len;h <<= 1)
    {
        Complex wn(cos(-on * 2 * PI / h), sin(-on * 2 * PI / h));
        for (int j = 0;j < len;j += h)
        {
            Complex w(1, 0);
            for (int k = j;k < j + h / 2;k++)
            {
                Complex u = y[k];
                Complex t = w*y[k + h / 2];
                y[k] = u + t;
                y[k + h / 2] = u - t;
                w = w*wn;
            }
        }
    }
    if (on == -1)
        for (int i = 0;i < len;i++)
            y[i].r /= len;
}


const int M = 50000;          // a数组所有元素+M,使a[i]>=0
const int MAXN = 800040;

Complex x1[MAXN];
int a[MAXN / 4];                //原数组
long long num[MAXN];     //利用FFT得到的数组
long long tt[MAXN];        //统计数组每个元素出现个数

int main()
{
    int n = 0;                   // n表示除了0之外数组元素个数
    int tot;
    scanf("%d", &tot);
    memset(num, 0, sizeof(num));
    memset(tt, 0, sizeof(tt));

    int cnt0 = 0;             //cnt0 统计0的个数
    int aa;

    for (int i = 0;i < tot;i++)
    {
        scanf("%d", &aa);
        if (aa == 0) { cnt0++;continue; }  //先把0全删掉,最后特殊考虑0
        else a[n] = aa;
        num[a[n] + M]++;
        tt[a[n] + M]++;
        n++;
    }

    sort(a, a + n);
    int len1 = a[n - 1] + M + 1;
    int len = 1;

    while (len < 2 * len1) len <<= 1;

    for (int i = 0;i < len1;i++) {
        x1[i] = Complex(num[i], 0);
    }
    for (int i = len1;i < len;i++) {
        x1[i] = Complex(0, 0);
    }
    fft(x1, len, 1);

    for (int i = 0;i < len;i++) {
        x1[i] = x1[i] * x1[i];
    }
    fft(x1, len, -1);

    for (int i = 0;i < len;i++) {
        num[i] = (long long)(x1[i].r + 0.5);
    }

    len = 2 * (a[n - 1] + M);

    for (int i = 0;i < n;i++) //删掉ai+ai的情况
        num[a[i] + a[i] + 2 * M]--;
    /*
    for(int i = 0;i < len;i++){
    if(num[i]) cout<<i-2*M<<' '<<num[i]<<endl;
    }
    */
    long long ret = 0;

    int l = a[n - 1] + M;

    for (int i = 0;i <= l; i++)   //ai,aj,ak都不为0的情况
    {
        if (tt[i]) ret += (long long)(num[i + M] * tt[i]);
    }

    ret += (long long)(num[2 * M] * cnt0);        // ai+aj=0的情况

    if (cnt0 != 0)
    {
        if (cnt0 >= 3) {                   //ai,aj,ak都为0的情况
            long long tmp = 1;
            tmp *= (long long)(cnt0);
            tmp *= (long long)(cnt0 - 1);
            tmp *= (long long)(cnt0 - 2);
            ret += tmp;
        }
        for (int i = 0;i <= l; i++)
        {
            if (tt[i] >= 2) {             // x+0=x的情况      
                long long tmp = (long long)cnt0;
                tmp *= (long long)(tt[i]);
                tmp *= (long long)(tt[i] - 1);
                ret += tmp * 2;
            }
        }
    }

    printf("%lld\n", ret);

    return 0;
}
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐