跳至主要內容

KAN: Kolmogorov–Arnold Networks

Xenny约 4445 字大约 19 分钟KANKAN

KAN: Kolmogorov–Arnold Networks

  • arxiv链接:arxiv.org/abs/2404.19756open in new window

  • 本文提出了基于KA(Kolmogorov Arnold)定理的K-A网络并与MLP进行比较。与MLP最大的不同是MLP在节点(神经元)上具有固定的激活函数,而KAN在边(权重)上具有可学习的激活函数。即KAN天生就不是线性权重——每个权重参数都被参数化为样条曲线的单变量函数所替代。

    本文发现这个改动使得KAN在准确性和可解释性方面要优于MLP。具体而言在数据拟合和PDE求解,更小的KAN便可以与一个大MLP取得相同的准确性。同时KAN可以直观的可视化,便于人工交互,即更好的可解释性。总而言之本文认为KAN非常有理由成为MLP的替代品。

知识补充

贝塞尔曲线

样条函数(spline)函数

  • 分段光滑,各段交接处也一定光滑性的多段函数。用于插值或者拟合。

    和多项式函数的不同便是拟合(或者插值)时是分段操作的,即分为nn个段,每个段单独拟合一条曲线,曲线与曲线(段与段)交接点在额外处理保证一定光滑性。

  • 后续本文中的B-样条函数即贝塞尔曲线的一般形式,参考速通:贝塞尔曲线和B样条open in new window

    简单来说贝塞尔曲线就是基于分段的全局拟合,而B-样条中增加了支撑集的概念使得不会因为局部改动影响全部。

Introduction

图1.1 MLP vs. KAN
图1.1 MLP vs. KAN
  • 如图1.1可见MLP与KAN的不同,对于KAN而言不存在线性权重矩阵,每个权重参数都被参数化为可学习一维(单变量)样条函数替代。在KAN中节点只需要进行加权求和即可,而不需要在节点处使用激活函数。

    这里人们可能会担心权重参数变成样条函数后成本变高,但是本文提出因为KAN相比MLP只需要更小的模型,例如对于PDE求解,2层宽度10的KAN比4层宽度100的MLP精确100倍,参数效率高100倍。(所以实际成本可能不会增加反而有更多的好处)

    在更早之前也有很多关于基于KA的神经网络研究工作,不过都停留在2层宽度2n+12n+1表示,而且没有用到现代技术(例如BP)训练网络。本文工作将KA表示推广到任意深度和宽度。

  • 本文也承认KAN本质上是样条函数和MLP的组合,利用各自的优势避开各自的劣势。对于样条曲线而言其适用于低维函数,易于局部调整,但存在维数诅咒(COD)问题。对于MLP而言基于特征学习适用与高维,但在低维中不如样条函数精确。所以KAN对外(整体)上为MLP,内部具有样条函数,这种结构使得KAN不仅可以学习特征还可以准确优化这些学习的特征。例如对于高维函数

    f(x1,,xN)=exp(1Ni=1Nsin2(xi))(1.1) f(x_1, \dots, x_N) = \exp\left(\frac{1}{N}\sum_{i=1}^N{\sin^2(x_i)}\right)\tag{1.1}

    由于COD的存在样条函数在NN较大时将失效,而MLP则可以学习广义加性结构(generalized additive structure),不过对于使用ReLU来近似指数函数和正弦函数效率非常低(也就是为啥MLP要那么多参数)。所以KAN在这类情况中远远优于MLP。

图2.1 文章结构
图2.1 文章结构
  • 本文的结构如图2.1所示。第二节中将介绍KAN结构及数学基础。第三节展示数据拟合和PDE求解方面和MLP的对比。第四节展示KAN良好的可解释性。第五节中总结相关工作。第六节讨论KAN的影响和未来展望。

2 Kolmogorov–Arnold Networks (KAN)

2.1 Kolmogorov-Arnold Representation theorem

  • Vladimir ArnoldAndrey Kolmogorov提出如果ff是有界域上的多元连续函数,则ff可以写成单变量连续函数和加法运算的有限组合。例如对于光滑函数f:[0,1]nRf:[0,1]^n\rightarrow\mathbb{R}

    f(x)=f(x1,,xn)=q=12n+1Φq(p=1nϕq,p(xp))(2.1) f(\mathbf{x}) = f(x_1, \cdots, x_n) = \sum_{q=1}^{2n+1}\Phi_q\left(\sum_{p=1}^n{\phi_{q,p}(x_p)}\right)\tag{2.1}

    其中ϕq,p:[0,1]R\phi_{q,p}:[0, 1]\rightarrow\mathbb{R}Φq:RR\Phi_q:\mathbb{R}\rightarrow\mathbb{R}。这代表只有加法是一个多变量函数(f(x,y)=x+yf(x,y) = x+y),因为其他所有函数都可以用单变量函数和加法进行表示。对于ML而言,这代表学习高维函数可能可以归约为学习多项式数量的一维函数。不过实际上是不行的,因为这些一维函数可能是非光滑的,甚至是分形的,也即在实践中可能是不可学习的。所以之前的工作都认为K-A表示在ML中基本没啥用(理论有用,实际无用)的状态。

  • 本文乐观的认为K-A还是有用的,基于此做了两部分工作,一是将网络推广到任意宽度和深度,二是研究一般性任务中KAN的具体表现。

2.2 KAN architecture

图2.2 左:网络架构。右:参数化B样条函数以及两种粒度
图2.2 左:网络架构。右:参数化B样条函数以及两种粒度
  • 对于一组监督学习任务{xi,yi}\{\mathbf{x}_i, y_i\},我们想要找到函数ff使得yif(xi)y_i\approx f(\mathbf{x}_i)对所有数据对成立。由式2.1可知如果能找到合适的单变量函数ϕq,p,Φq\phi_{q,p},\Phi_q便可以得到ff。由此可以设计一个参数化方程的神经网络,即将每个要学习的单变量函数参数化为具有局部B-样条基的可学习系数B-样条函数(图2.2右侧)。这便是KAN的原型,计算图由式2.1指定,最终如图1.1(b)所示(n=2)的一个两层神经网络。

    本文的突破性工作便是结合MLP的多层结构,定义了KAN层——具有ninn_{in}维输入及noutn_{out}输出的KAN层由一维函数矩阵表示为

    Φ={ϕq,p}, p=1,2,,nin, q=1,2,,nout(2.2) \Phi = \{\phi_{q,p}\},\ p = 1,2,\dots,n_{in},\ q = 1,2,\dots,n_{out}\tag{2.2}

    其中ϕq,p\phi_{q,p}具有可训练参数,后续即可和MLP相同的方式(堆叠多个KAN层)来获得更深层的神经网络。

  • 如图2.2(左)示,KAN形状可以由一个整数数组表示,

    [n0,n1,,nL](2.3) [n_0,n_1,\cdots,n_L]\tag{2.3}

    其中nin_i表示计算图第ii层的节点数。用(l,i)(l,i)-神经元代表第ll层的第ii个神经元,用xl,ix_{l,i}代表该神经元的激活值。在第ll层与l+1l+1层之间存在nlnl+1n_ln_{l+1}个激活函数,连接(l,j)(l,j)(l+1,i)(l+1,i)的激活函数表示为

    ϕl,i,j, l=0,,L1,i=1,,nl+1, j=1,,nl(2.4) \phi_{l,i,j},\ l = 0,\cdots,L-1, i = 1,\dots,n_{l+1},\ j=1,\dots,n_l\tag{2.4}

    同时ϕl,i,j\phi_{l,i,j}预激活(激活前的)值表示为xl,ix_{l,i},后激活(激活后的)值表示为x^l,i,jϕl,i,j(xl,i)\hat{x}_{l,i,j} \equiv \phi_{l,i,j}(x_{l,i})(l+1,j)(l+1,j)-神经元的激活值为所有传入值的和

    xl+1,j=i+1nlx^l,i,j=i+1nlϕl,i,j(xl,i), j=1,,nl+1(2.5) x_{l+1,j} = \sum_{i+1}^{n_l}\hat{x}_{l,i,j} = \sum_{i+1}^{n_l}{\phi_{l,i,j}(x_{l,i})},\ j=1,\dots,n_{l+1}\tag{2.5}

    换成矩阵形式即

    xl+1=(ϕl,1,1()ϕl,1,2()ϕl,1,n1()ϕl,2,1()ϕl,2,2()ϕl,2,n1()ϕl,nl+1,1()ϕl,nl+1,2()ϕl,nl+1,n1())Φlxl(2.6) \mathbf{x}_{l+1} = \underbrace{\begin{pmatrix} \phi_{l,1,1}(\cdot)&\phi_{l,1,2}(\cdot)&\cdots&\phi_{l,1,n_1}(\cdot)\\ \phi_{l,2,1}(\cdot)&\phi_{l,2,2}(\cdot)&\cdots&\phi_{l,2,n_1}(\cdot)\\ \vdots&\vdots&&\vdots\\ \phi_{l,n_{l+1},1}(\cdot)&\phi_{l,n_{l+1},2}(\cdot)&\cdots&\phi_{l,n_{l+1},n_1}(\cdot)\\ \end{pmatrix}}_{\Phi_l}\mathbf{x}_l\tag{2.6}

    Φi\Phi_{i}表示第ll个KAN层的函数矩阵。常规KAN由LL层组成:给定输入向量x0Rn0\mathbf{x}_0\in\mathbb{R}^{n_0},KAN的输出为

    KAN(x)=(ΦL1ΦL2ΦL1Φ1Φ0)x(2.7) \mathrm{KAN}(\mathbf{x}) = (\Phi_{L-1}\circ\Phi_{L-2}\Phi_{L-1}\circ\cdots\circ\Phi_{1}\circ\Phi_{0})\mathbf{x}\tag{2.7}

Implementation details

  • 虽然KAN层方程很简单,但是还需要一些优化技巧:
  1. 残差激活函数(Residual activation functions):选择一个基函数b(x)b(x)(类似残差连接),满足激活函数ϕ(x)\phi(x)是基函数b(x)b(x)和样条函数之和

    ϕ(x)=ω(b(x)+spline(x))(2.8) \phi(x) = \omega(b(x)+\mathrm{spline}(x))\tag{2.8}

    其中

    b(x)=silu(x)=x/(1+ex)(2.9) b(x) = \mathrm{silu}(x) = x/(1+e^{-x})\tag{2.9}

    在大多数情况下样条函数spline(x)\mathrm{spline}(x)被参数化为B-样条曲线的线性组合,满足

    spline(x)=iciBi(x)(2.10) \mathrm{spline}(x) = \sum_i{c_iB_i(x)}\tag{2.10}

    其中cic_i是训练参数,ω\omega为权重因子。

  2. 初始化

    每个激活函数初始化为spline(x)0\mathrm{spline}(x)\approx 0ω\omega由Xavier初始化确定。

  3. 样条网格更新

    样条曲线定义在有界区域上但激活值在训练过程中可能会从固定区域演化出来,所以这里选择根据输入激活值动态更新每个网格。

Parameter count

  • 对于一个深度为LL,每层宽度为n0=n1==nL=Nn_0=n_1=\cdots=n_L = N,每个3次样条包含GG段(G+1G+1个格点)。

    则一共有O(N2L(G+k))O(N2LG)O(N^2L(G+k))\sim O(N^2LG)个参数。同尺寸的MLP包含O(N2L)O(N^2L)个参数,但KAN大部分情况需要的尺寸比MLP更小且泛化能力更好。

2.3 KAN’s Approximation Abilities and Scaling Laws

  • 在式2.1中,2层深度2n+12n+1宽度的表示可能是非光滑的,而更深层表示可能会带来更平滑的激活值。例如一个4变量函数

    f(x1,x2,x3,x4)=exp(sin(x12+x22)+sin(x32+x42))(2.11) f(x_1,x_2,x_3,x_4) = \exp(\sin(x_1^2+x_2^2) + \sin(x_3^2 + x_4^2))\tag{2.11}

    可以由3层KAN[4,2,1,1][4,2,1,1]平滑表示,但是可能无法用具有平滑激活函数的2层网络来表示。为了便于近似分析,本文仍假设激活函数是平滑的,且允许表示层为任意宽度和深度。

  • 定理2.1(近似定理KAT):设x=(x1,x2,,xn)\mathbf{x} = (x_1,x_2,\cdots,x_n),函数f(x)f(x)代表一个表示层(类似式2.7)

    f=(ΦL1ΦL2Φ1Φ0)x(2.12) f = (\Phi_{L-1}\circ\Phi_{L-2}\circ\cdots\circ\Phi_1\circ\Phi_0)\mathbf{x}\tag{2.12}

    其中Φl,i,j\Phi_{l,i,j}(k+1)(k+1)阶可微。则存在依赖于ff的一个常数CC,网格大小为GG,有如下近似界:存在kk阶B-样条函数Φl,i,jG\Phi_{l,i,j}^G满足对于任意0mk0\le m \le k,有

    f(ΦL1GΦL2GΦ1GΦ0G)xCmCGk1+m(2.13) \lVert f - (\Phi_{L-1}^G\circ\Phi_{L-2}^G\circ\cdots\circ\Phi_1^G\circ\Phi_0^G)\mathbf{x} \rVert_{C^m} \le CG^{-k-1+m}\tag{2.13}

    这里使用CmC^m范数衡量mm阶导数的大小,有

    gCm=maxβmsupx[0,1]nDβg(x) \lVert g\rVert_{C^m} = \max\limits_{|\beta|\le m}\sup\limits_{x\in[0,1]^n}\lvert D^\beta g(x)\rvert

  • 也就是说可以用有限网格大小的KAN进行拟合,而且误差与维数无关,从而摆脱COD问题。对于CC是否和维度相关本文留给未来工作。

神经缩放定律

  • 神经缩放定律是指测试损失随着模型参数增加而减少的现象(就是越大越好)。例如lNα\mathcal{l}\propto N^{-\alpha},其中l\mathcal{l}为测试RMSE,NN为参数数目,α\alpha为缩放指数。

    在KAN中假设目标存在平滑的K-A表示,将高维函数分解为几个一维函数,则α=k+1\alpha = k+1kk为阶数),后续在节3.1证明了α=4\alpha=4也是最大最好的缩放指数。

Comparison between KAT and UAT

2.4 For accuracy: Grid Extension

  • 理论上样条可以精确拟合目标函数,因为网格可以任意精细度,KAN也正是由此特性所以比MLP更好。而且对于KAN而言,可以先训练一个小网络,再用更精细的网格扩展到更多参数的KAN而无需重新训练。

    网格扩展(如图2.2右所示)的目的便是将新的粒度样条曲线拟合到旧的粒度样条曲线中。例如我们想用kk阶样条拟合有界区域[a,b][a,b]中的一维函数ff,具有G1G_1段的网格点为{t0=a,t1,t2,,tG1=b}\{t_0 = a, t_1,t_2,\cdots, t_{G_1} = b\},将其扩充为{tk,,t1,t0,,tG1,tG1+1,,tG1+k}\{t_{-k},\cdots,t_{-1},t_0,\cdots,t_{G_1},t_{G_1+1},\cdots,t_{G_1+k}\},其中包含G1+kG_1+k个B-样条基函数,且第ii条B-样条Bi(x)B_i(x)仅在[tk+i,ti+1], i=0,,G1+k1[t_{-k+i},t_{i+1}],\ i=0,\dots,G_1+k-1上非0.

    然后对于旧网格上的ff将右这些B-样条基函数线性表示,fcoarse(x)=i=0G1+k1ciBi(x)f_{\mathrm{coarse}}(x) = \sum_{i=0}^{G_1+k-1}{c_iB_i(x)}。给定一个具有G2G_2段的精细网格,新网格上的ff相应地为ffine(x)=i=0G1+k1ciBi(x)f_{\mathrm{fine}}(x) = \sum_{i=0}^{G_1+k-1}{c_i^{'}B_i^{'}(x)}。其中参数cjc_j^{'}将通过最小化ffine(x)f_{\mathrm{fine}}(x)fcoarse(x)f_{\mathrm{coarse}}(x)之间的距离(最小二乘法实现)进行初始化。

    图2.3 网格扩展
    图2.3 网格扩展
  • 示例

    这里将使用不同的KAN来拟合f(x,y)=exp(sin(πx)+y2)f(x,y) = \exp(\sin(\pi x) + y^2),结果如图2.3所示。可以看到随着网格增加训练和测试损失都减小,但当达到一定大小后,测试损失不降反增(过拟合)。而最后当网格达到1000时可能时由于优化算法由于损失函数地形不佳而停止工作。

    这里文中预计对于[2,5,1][2,5,1]KAN(包含15G15G个参数)将在G=1000/1567G=1000/15\approx 67(1000为数据样本)处取插值阈值(图2.3中红色虚线)。

  • Small KANs generalize better.

    在图2.3中可间,[2,1,1][2,1,1]KAN比[2,5,1][2,5,1]KAN有着更低的测试loss,而且其插值阈值更高(即泛化能力更强),这也是KAN中的关键问题,如何确定最小KAN形状?在本文节2.5中将给出一种通过正则化和修建来自动发现最小KAN架构的方法。

  • Scaling laws

    这里探究测试loss和网格参数数量之间的关系,对于[2,1,1][2,1,1]KAN其尺度大致满足testRMSEG3\mathrm{test} \mathrm{RMSE}\propto G^{-3}。根据定理2.1可知期望为testRMSEG4\mathrm{test} \mathrm{RMSE}\propto G^{-4}。文中的解释时可能因为样本误差不均匀,所以图2.3中还放了误差中值得变化发现接近G4G^{-4}的尺度。

  • External vs Internal degrees of freedom.

    KAN中的一个新概念是外部自由度和内部自由度之间的区别。节点连接方式的计算图表示外部自由度("dofs"),激活函数内部的网格表示内部自由度。外部dofs负责学习多个变量组成的结构,内部dofs复杂学习单变量函数。

2.5 For Interpretability: Simplifying KANs and Making them interactive

  • 这里的核心思想是训练一个足够大的KAN,再利用稀疏性正则化和修建进行训练。以及介绍了人工交互技术讨论KAN的可解释性。

简化技术

  1. 稀疏化

    类似MLP引入正则项进行正则化的思想,不同的是KAN中需要两次修改。

    1. KAN没有线性“权重”(被可学习函数取代),所以需要重新定义L1范数。
    2. 本文发现L1不足以稀疏化KAN,所以这里选用增加熵进行正则化。

    本文定义对于NpN_p个输入的激活函数的L1范数为

    ϕ11Nps=1Npϕ(x(s))(2.14) \lvert \phi\rvert_1 \equiv \frac{1}{N_p}\sum_{s=1}^{N_p}\left\lvert\phi(x^{(s)})\right\rvert\tag{2.14}

    对于具有NinN_{\mathrm{in}}个输入NoutN_{\mathrm{out}}个输出的KAN层中的Φ\Phi,将Φ\Phi的L1范数定义为所有激活函数的L1范数和,即

    Φ1i=1Ninj=1Noutϕi,j1(2.15) \lvert \Phi\rvert_1 \equiv \sum_{i=1}^{N_{\mathrm{in}}}\sum_{j=1}^{N_{\mathrm{out}}}\lvert\phi_{i,j}\rvert_1\tag{2.15}

    最后定义Φ\Phi的熵为

    S(Φ)i=1Ninj=1Noutϕi,j1Φ1log(ϕi,j1Φ1)(2.16) S(\Phi) \equiv -\sum_{i=1}^{N_{\mathrm{in}}}\sum_{j=1}^{N_{\mathrm{out}}}\frac{\lvert\phi_{i,j}\rvert_1}{\lvert\Phi\rvert_1}\log\left(\frac{\lvert\phi_{i,j}\rvert_1}{\lvert\Phi\rvert_1}\right)\tag{2.16}

    最终训练lossltotal\mathcal{l}_{total}为预测损失lpred\mathcal{l}_{pred}加L1和所有KAN层的熵正则项。

    ltotal=lpred+λ(μ1l=0L1Φl1+μ2l=0L1S(Φl))(2.17) \mathcal{l}_{total} = \mathcal{l}_{pred} + \lambda\left(\mu_1\sum_{l=0}^{L-1}\lvert \Phi_l\rvert_1 + \mu_2\sum_{l=0}^{L-1}{S(\Phi_l)}\right)\tag{2.17}

    其中μi\mu_i为正数通常μ1=μ2=1\mu_1 = \mu_2 = 1λ\lambda为超参数。

  2. 可视化

  3. 修剪

    用稀疏化惩罚训练后,可以在节点上(而不是边上)进一步稀疏化KAN,对于每个节点定义传入传出的分数为

    Il,i=maxk(ϕl1,k,i)1,  Ol,i=maxj(ϕl+1,j,i1)(2.18) I_{l,i} = \max\limits_{k}(\lvert\phi_{l-1,k,i})\rvert_1,\ \ O_{l,i} = \max\limits_{j}(\lvert\phi_{l+1,j,i}\rvert_1)\tag{2.18}

    如果传入传出分数大于都阈值则认为节点很重要,默认情况选择超参数θ=102\theta = 10^{-2},小于阈值的节点将被修剪(删除)。

  4. 符号化(Symbolification)

    当怀疑某些激活函数是符号时(指有名字,例如cos,log),KAN提供了一个接口将它们设置为指定的符号,例如fix_symbolic(l, i, j, f)代表将(l,i,j)(l, i, j)激活函数设置为ff

图2.4 简化技术示例
图2.4 简化技术示例

个人总结

  1. KAN与MLP的不同便是KAN中的激活函数从固定变成 -> 变动。

    具体变动的方法是使用GG段的B-样条函数进行拟合。

    通过多层网络来将高维任务化简为多项式个一维任务。且由KAT解决了样条曲线COD问题(即误差足够小)。

  2. 网格扩展技术

    使得训练时能够动态变换网格大小改变精度而无需重新设置参数从头训练。

  3. 简化技术

    定义了新的正则项以及修剪技术简化网络。

  4. 可解释性

    传统MLP是值的变化,所以遗忘和可解释性都是问题,个人认为本质还是数据维度的问题,单个值的信息太少了,纯靠堆参数。而KAN将学习参数变成一个函数,通过函数的变化更好的体现学习过程以及函数是可以再使用的。