上一篇文章介绍了如何把 BatchNorm 和 ReLU 合并到 Conv 中,这篇文章会介绍具体的代码实现。本文相关代码都可以在 github 上找到。
Folding BN
回顾一下前文把 BN 合并到 Conv 中的公式: \[ \begin{align} y_{bn}&=\frac{\gamma}{\sqrt{\sigma_y^2+\epsilon}}(\sum_{i}^N w_i x_i + b-\mu_y)+\beta \notag \\ &=\gamma'(\sum_{i}^Nw_ix_i+b-\mu_y)+\beta \notag \\ &=\sum_{i}^N \gamma'w_ix_i+\gamma'(b-\mu_y)+\beta \tag{1} \end{align} \] 其中,\(x\) 是卷积层的输入,\(w\)、\(b\) 分别是 Conv 的参数 weight 和 bias,\(\gamma\)、\(\beta\) 是 BN 层的参数。
对于 BN 的合并,首先,我们需要熟悉 pytorch 中的 BatchNorm2d
模块。
pytorch 中的 BatchNorm2d
针对 feature map 的每一个 channel 都会计算一个均值和方差,所以公式 (1) 需要对 weight 和 bias 进行 channel wise 的计算。另外,BatchNorm2d
中有一个布尔变量 affine
,当该变量为 true 的时候,(1) 式中的 \(\gamma\) 和 \(\beta\) 就是可学习的, BatchNorm2d
会中有两个变量:weight
和 bias
,来分别存放这两个参数。而当 affine
为 false 的时候,就直接默认 \(\gamma=1\),\(\beta=0\),相当于 BN 中没有可学习的参数。默认情况下,我们都设置 affine=True
。
我们沿用之前的代码,先定义一个 QConvBNReLU
模块:
1 | class QConvBNReLU(QModule): |
这个模块会把全精度网络中的 Conv2d 和 BN 接收进来,并重新封装成量化的模块。
接着,定义合并 BN 后的 forward 流程:
1 | def forward(self, x): |
这个过程就是对 Google 论文的那张图的诠释,跟一般的卷积量化的区别就是需要先获得 BN 层的参数,再把它们 folding 到 Conv 中,最后跑正常的卷积量化流程。不过,根据论文的表述,我们还需要在 forward 的过程更新 BN 的均值、方差,这部分对应上面代码 if self.training
分支下的部分。
然后,根据公式 (1),我们可以计算出 fold BN 后,卷积层的 weight 和 bias:
1 | def fold_bn(self, mean, std): |
上面的代码直接参照公式 (1) 就可以看懂,其中 gamma_
就是公式中的 \(\gamma'\)。由于前面提到,pytorch 的 BatchNorm2d
中,\(\gamma\) 和 \(\beta\) 可能是可学习的变量,也可能默认为 1 和 0,因此根据 affine
是否为 True
分了两种情况,原理上基本类似,这里就不再赘述。
合并ReLU
前面说了,ReLU 的合并可以通过在 ReLU 之后统计 minmax,再计算 scale 和 zeropoint 的方式来实现,因此这部分代码非常简单,就是在 forward 的时候,在做完 relu 后再统计 minmax 即可,对应代码片段为:
1 | def forward(self, x): |
将 BN 和 ReLU 合并到 Conv 中,QConvBNReLU
模块本身就是一个普通的卷积了,因此量化推理的过程和之前文章的 QConv2d
一样,这里不再赘述。
实验
这里照例给出一些实验结果。
本文的实验还是在 mnist 上进行,我重新定义了一个包含 BN 的新网络:
1 | class NetBN(nn.Module): |
量化该网络的代码如下:
1 | def quantize(self, num_bits=8): |
整体的代码风格基本和之前一样,不熟悉的读者建议先阅读我之前的量化文章。
先训练一个全精度网络「相关代码在 train.py 里面」,可以得到全精度模型的准确率是 99%。
然后,我又跑了一遍后训练量化以及量化感知训练,在不同量化 bit 下的精度如下表所示「由于学习率对量化感知训练的影响非常大,这里顺便附上每个 bit 对应的学习率」:
bit | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
后训练量化 | 10% | 11% | 10% | 35% | 82% | 85% | 85% | 87% |
量化感知训练 | 10% | 19% | 59% | 91% | 92% | 94% | 94% | 95% |
lr | 0.00001 | 0.0001 | 0.02 | 0.02 | 0.02 | 0.02 | 0.02 | 0.04 |
对比之前文章的结果,加入 BN 后,后训练量化在精度上的下降更加明显,而量化感知训练依然能带来较大的精度提升。但在低 bit 情况下,由于信息损失严重,网络的优化会变的非常困难。
总结
这篇文章给出了 Folding BN 和 ReLU 的代码实现,主要是想帮助初学者加深对公式细节的理解。至此,这系列教程基本告一段落,希望能帮助小白们快速入门这一领域。
PS: 之后的文章更多的会发布在公众号上,欢迎有兴趣的读者关注我的个人公众号:AI小男孩,扫描下方的二维码即可关注