代码
自定义函数需要继承 torch.autograd.Function
类,并实现两个静态方法 forward
和 backward
import torch
# 自定义函数
class MyFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
"""
该自定义函数作用是将输入乘以2
使用以下方法将函数输入保存下来给backward中使用
ctx.save_for_backward(input)
"""
return input * 2
@staticmethod
def backward(ctx, grad_output):
"""
返回一个固定的梯度
这里结果必须乘上grad_output,下文有解释
使用以下方法获取forward过程中的输入
input, = ctx.saved_tensors
"""
return torch.tensor(3.) * grad_output
# 使用
x = torch.tensor(1., requires_grad=True)
y = MyFunc().apply(x) # 注意使用apply函数调用
y.backward()
print(f'grad: {x.grad}')
# 输出
# grad: 3.0
解释
两个函数的作用
forward 函数用于forward过程中,backward 函数用于backward过程中,两者并没有直接关系,所以上述代码其实是 \(y=w*x \quad where \quad x=1, w=2\) ,但 \(\frac {dy}{dw}=3\)(本来应该等于2的)。这也是为什么backward中要想得到函数的输入就需要在forward中先保存
backward的grad_output参数
有如下函数:
y = f1(x)
z = f2(y)
则 \(\frac {dz}{dx} = \frac {dz}{dy} * \frac {dy}{dx} \) 这就是链式求导法则。假设自定义函数处于整个计算过程中的中间部分(例如上式中的f1),backward中的grad_output参数就是该函数之前所有函数的求导结果(f2的求导结果),也就是说,假如你按照我上面的代码那样给一个定值,则其前面计算的导数都是无效的,但我们知道,求导在模型参数更新中主要的作用是计算参数增量的方向,如果只计算整个模型的部分导数可能导致该增量符号错误,进而影响模型训练,所以需要在结果上乘以该参数。
forward的input参数
一个容易混淆的概念就是这个input指的是什么,很容易就和 torch.nn.Module 中的 forward(input) 函数对应起来,但其实他们是完全不同的两种input。
torch.nn.Module 中forward函数的input指的是数据集或者待预测的数据张量(即 y=w*x 中的x),其 requires_grad=False
,而 torch.autograd.Function 中forward函数的input指的是待学习参数(即 y=w*x 中的w),其 requires_grad=True
当然,你也可以传递 requires_grad=False
的参数给它,但在backward中需要return None,例如你可以看下文中的例子
backward的返回值
forward有多个参数的情况下,backward也应该有一一对应的梯度返回值,那么这些梯度返回值应该是什么shape的呢?
答曰:梯度返回值shape的后几位与input的shape应该相同。例如input的shape为[3, 64, 64],则梯度返回值的shape应该为[…, 3, 64, 64](前面多余的维度会自动进行求和),并且在模型backward的时候需要指定input的维度(若为标量则不用指定),其实就是相当于将多个待训练参数组合成一个矩阵的形式了,所以需要告诉backward这个input指的是多个参数集合
...
loss = model(x)
loss.backward(torch.ones_like(input.shape))
...
多参数举例
我需要实现一个带有学习率的模型,它能够实现以下功能:
# 该函数将输入input进行截断,低于low_threshold或高于high_threshold的值都会使用该阈值进行代替,并且使用指定的scale对整体数据进行缩放
# 其中,low_threshold, high_threshold 是需要学习的参数,而scale是指定的超参数
clamp(input, low_threshold, high_threshold, scale)
以下为实现代码
class ClampModel(torch.nn.Module):
def __init__(self, scale):
super(ClampModel, self).__init__()
self.scale = scale
self.clamp_low_threshold = torch.nn.Parameter(data=torch.tensor(-150.), requires_grad=True)
self.clamp_high_threshold = torch.nn.Parameter(data=torch.tensor(150.), requires_grad=True)
self.clamp = Clamp().apply
def forward(self, x):
x = self.clamp(x, self.clamp_low_threshold, self.clamp_high_threshold, self.scale)
return x
class Clamp(torch.autograd.Function):
@staticmethod
def forward(ctx, x, low_threshold, high_threshold, scale):
return torch.clamp(x, low_threshold, high_threshold) * scale
@staticmethod
def backward(ctx, grad_output):
return grad_output * scale, grad_output * torch.tensor(0.1), grad_output * torch.tensor(0.1), None # 梯度返回值需和forward参数顺序对应,非学习参数需要返回None
# 使用方法
m = ClampModel(scale=2)
output = m(input)
output .backward()
...
参考
官网:https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html