跳至主要內容

四、PyTorch—自定义反向传播

原创Xenny约 726 字大约 3 分钟机器学习机器学习PyTorch

四、PyTorch—自定义反向传播

  • 神经网络 —— 损失函数 & 反向传播open in new window中已经介绍了反向传播的基本概念。

    反向传播的本质就是NN条偏导链,通过最后的loss值往前逐层计算梯度,更新权重。

    PyTorch中,使用loss.backward()便是执行反向传播的过程,不过本文的重点主要是了解如何自定义反向传播过程。一般来说,有两种情况我们需要用到自定义反向传播:

    1. 引入了不可微的函数,此时需要自定义求导方式;
    2. 希望在偏导链中引入新的约束(类似损失函数的惩罚项)。

反向传导

  • 我们考虑如下的计算过程

    c = a + b
    e = c + d
    f = b + c
    o = e + f
    

    现在我们要计算a的梯度。显然,我们需要先计算c的梯度,而c的梯度则需要先计算ef,同理我们会得到一张反向传导图,那么此时我们如何

    在PyTorch中,反向传导图是通过function节点进行构建,并且维护一个ReadyQueue优先队列,通过比较节点的依赖序列来决定谁先进行计算。

自定义backward

  • 我们可以通过自定义模型的backward函数来自定义求导过程。PyTorch在官网中给了一个线性回归的例子,对于我们想要定义如下函数

    y=xω+b(1) \mathbf{y} = \mathbf{x}\omega + \mathbf{b}\tag{1}

    我们设自定义求导公式为

    yx=ωyω=xyb=1(2) \begin{aligned} \frac{\partial \mathbf{y}}{\partial \mathbf{x}} = \omega\\ \frac{\partial \mathbf{y}}{\partial \omega} = \mathbf{x}\\ \frac{\partial \mathbf{y}}{\partial \mathbf{b}} = 1 \end{aligned}\tag{2}

    此时我们可以通过继承Function节点来自定义求导。

    class CustomFunction(Function):
        @staticmethod
        def forward(ctx, input, weight, bias=None):
            ctx.save_for_backward(input, weight, bias)
            ouput = input.mm(weight.t())
            if bias is not None:
                output += bias.unsqueeze(0).expand_as(output)
            
            return output
        
        @staticmethod
        def backward(ctx, grad_output):
            input, weight, bias = ctx.saved_tensors
            grad_input = grad_weight = grad_bias = None
    
            if ctx.needs_input_grad[0]:
                grad_input = grad_output.mm(weight)
            if ctx.needs_input_grad[1]:
                grad_weight = grad_output.t().mm(input)
            if bias is not None and ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum(0)
            
            return grad_input, grad_weight, grad_bias
    

    从代码中可以看到函数分为两个部分,forward定义了正向过程,即函数计算,backward定义了逆向过程,即反向传播。

    forward中我们使用save_for_backward保存当前张量。在backward中可以从saved_tensors读取正向过程中保存的张量。grad_output中包含了正向输出对于最终损失函数的梯度,即output的梯度,此时我们可以通过自定义的求导法则来计算每个参数的梯度并返回,其中needs_input_grad代表参数是否需要梯度,顺序为forward的形参顺序。

    在网络中,我们使用CustomFunction.apply(input, weight, bias)来调用该函数,backward函数会在整个网络反向传播时自动调用。