torch.optim 是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。使用 torch.optim,必须构造一个 optimizer 对象。这个对象能保存当前的参数状态并且基于计算梯度更新参数。被优化的参数一般是 model.parameters(),当有特殊需求时可以手动写一个 dict 来作为输入。例如:这样 model.base 或者说大部分的参数使用 1e-2 的学习率,而 model.classifier 的参数使用 1e-3 的学习率,并且 0.9 的 momentum 被用于所有的参数。在进行反向传播之前,必须要用 zero_grad() 清空梯度。具体的方法是遍历 self.param_groups 中全部参数,根据 grad 属性做清除。

PyTorch的Optimizer训练工具的实现

PyTorch的Optimizer训练工具的实现