文章目录
- 前言
- 引入
- Demo实现
- 总结
前言
在上一篇博文RepVGG中,介绍了RepVGG网络。RepVGG 作为一种高效的重参数化网络,通过训练时的多分支结构(3x3卷积、1x1卷积、恒等映射)和推理时的单分支合并,在精度与速度间取得了优秀平衡。然而,其在低精度(如INT8)量化后常出现显著精度损失。
本文将要介绍的QARepVGG(Make RepVGG Greater Again: A Quantization-aware Approach)的提出正是为了解决这一问题。其核心贡献在于基础的Block设计:
引入
文章做了详细的消融实验来一步一步的推理出这种结构,本文在此不多做赘述。只大概提一下:RepVGG其实是由三个单元构成:权重、BN和ReLU。卷积操作一般不会影响权重值的改变,基本服从0~1分布;而根据BN层的公式,会出现一个乘法项,导致方差可能发生改变;另外,如果输入的数值范围很大,经过ReLU也会产生大的方差项,导致量化困难。
因此,QARepVGG去掉了BN层,并在三个分支后新加了一个BN层来将分布改成一个量化友好的分布。
当然,建议读者阅读原论文,好多实验的设计跟分析很透彻。
Demo实现
本文旨在复现一个QARepVGG Block,读者可一键运行:
import torch
import torch.nn as nn
class QARepVGGBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
assert in_channels == out_channels, "输入输出通道必须相同!"
# 分支1:3x3卷积 + BN
self.conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
self.bn3x3 = nn.BatchNorm2d(out_channels)
# 分支2:1x1卷积(无BN)
self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
# 分支3:恒等映射(无BN)
self.identity = nn.Identity() # 直接传递输入
# 合并后的BN层
self.final_bn = nn.BatchNorm2d(out_channels)
# 初始化权重(关键!)
self._init_weights()
def _init_weights(self):
"""显式初始化权重"""
nn.init.kaiming_normal_(self.conv3x3.weight, mode='fan_out', nonlinearity='relu')
nn.init.zeros_(self.conv1x1.weight) # 初始化为零,与恒等映射互补
def forward(self, x):
# 分支1:3x3卷积 + BN
branch3x3 = self.bn3x3(self.conv3x3(x))
# 分支2:1x1卷积
branch1x1 = self.conv1x1(x)
# 分支3:恒等映射
branch_id = self.identity(x)
# 合并后通过最终BN
out = self.final_bn(branch3x3 + branch1x1 + branch_id)
return out
def reparameterize(self):
"""将多分支合并为单一3x3卷积,并融合BN参数"""
# 1. 将各分支转换为等效3x3卷积
# 分支1:3x3卷积 + BN3x3
kernel3x3, bias3x3 = self._fuse_conv_bn(self.conv3x3, self.bn3x3)
# 分支2:1x1卷积(无BN),填充为3x3
kernel1x1 = self._pad_1x1_to_3x3(self.conv1x1.weight)
bias1x1 = torch.zeros_like(bias3x3) # 无偏置
# 分支3:恒等映射(视为1x1单位矩阵卷积,填充为3x3)
identity_kernel = torch.eye(self.conv3x3.in_channels, device=self.conv3x3.weight.device)
identity_kernel = identity_kernel.view(self.conv3x3.in_channels, self.conv3x3.in_channels, 1, 1)
kernel_id = self._pad_1x1_to_3x3(identity_kernel)
bias_id = torch.zeros_like(bias3x3)
# 2. 合并所有分支的权重和偏置
merged_kernel = kernel3x3 + kernel1x1 + kernel_id
merged_bias = bias3x3 + bias1x1 + bias_id
# 3. 融合最终BN层参数
scale = self.final_bn.weight / (self.final_bn.running_var + self.final_bn.eps).sqrt()
merged_kernel = merged_kernel * scale.view(-1, 1, 1, 1)
merged_bias = scale * (merged_bias - self.final_bn.running_mean) + self.final_bn.bias
# 4. 构建合并后的卷积层
merged_conv = nn.Conv2d(
self.conv3x3.in_channels,
self.conv3x3.out_channels,
kernel_size=3,
padding=1,
bias=True
)
merged_conv.weight.data = merged_kernel
merged_conv.bias.data = merged_bias
return merged_conv
def _fuse_conv_bn(self, conv, bn):
"""融合卷积和BN的权重与偏置"""
kernel = conv.weight
running_mean = bn.running_mean
running_var = bn.running_var
gamma = bn.weight
beta = bn.bias
eps = bn.eps
std = (running_var + eps).sqrt()
scale_factor = gamma / std
fused_kernel = kernel * scale_factor.view(-1, 1, 1, 1)
fused_bias = beta - running_mean * scale_factor
return fused_kernel, fused_bias
def _pad_1x1_to_3x3(self, kernel):
"""将1x1卷积核填充为3x3(中心为原权重,其余为0)"""
if kernel.size(-1) == 1:
padded = torch.zeros(kernel.size(0), kernel.size(1), 3, 3, device=kernel.device)
padded[:, :, 1, 1] = kernel.squeeze()
return padded
return kernel
def test_qarepvgg():
torch.manual_seed(42)
# 输入数据(小方差加速BN收敛)
x = torch.randn(2, 3, 4, 4) * 0.1
# 初始化模块
block = QARepVGGBlock(3, 3)
# 训练模式:更新BN统计量
block.train()
for _ in range(100): # 充分训练
y = block(x)
y.sum().backward() # 伪反向传播
# 推理模式:合并权重
block.eval()
with torch.no_grad():
# 原始输出
orig_out = block(x)
# 合并后的卷积
merged_conv = block.reparameterize()
merged_out = merged_conv(x)
# 打印关键数据
print("out:", orig_out.mean().item())
print("merge:", merged_out.mean().item())
print("diff:", torch.abs(orig_out - merged_out).max().item())
# 验证一致性(容差1e-6)
assert torch.allclose(orig_out, merged_out, atol=1e-6), f"合并失败!最大差值:{torch.abs(orig_out - merged_out).max().item()}"
print("✅ 测试通过!")
test_qarepvgg()
总结
欢迎留言交流讨论。