This is an influential ICLR 2018 paper. Together with its sibling paper on spectral normalization, written by the same author, it had a big impact on the development of GANs, greatly stabilizing GAN training. It is also a very theoretical paper with a lot of derivations. This post will not cover the theory; following the architecture diagram given in the paper, I just walk through the tough parts I chewed on today.
- What is Categorical Conditional BatchNorm?
- How is the conditional information y fused into the discriminator via projection?

These two questions correspond to fusing the conditional information y into the generator via batch norm and into the discriminator via projection, both of which differ from the usual fusion by concatenation.
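For contrast, the conventional concat fusion simply tiles y over the spatial dimensions and stacks it onto the feature channels. A minimal sketch (all sizes here are illustrative, not from the repo):

```python
import torch

# hypothetical sizes: batch of 8, 64-channel feature maps, 10 classes
feat = torch.randn(8, 64, 16, 16)
y_onehot = torch.eye(10)[torch.randint(0, 10, (8,))]       # (8, 10)

# broadcast the label to every spatial location, then concat on channels
y_map = y_onehot[:, :, None, None].expand(-1, -1, 16, 16)  # (8, 10, 16, 16)
fused = torch.cat([feat, y_map], dim=1)                    # (8, 74, 16, 16)
```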
Let's first look at Categorical Conditional BatchNorm.
Categorical Conditional BatchNorm
```python
import torch.nn as nn
import torch.nn.functional as F


class ConditionalBatchNorm2d(nn.BatchNorm2d):

    """Conditional Batch Normalization"""

    def __init__(self, num_features, eps=1e-05, momentum=0.1,
                 affine=False, track_running_stats=True):
        super(ConditionalBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input, weight, bias, **kwargs):
        self._check_input_dim(input)
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # standard batch norm (affine=False, so self.weight/self.bias are None)
        output = F.batch_norm(input, self.running_mean, self.running_var,
                              self.weight, self.bias,
                              self.training or not self.track_running_stats,
                              exponential_average_factor, self.eps)

        # apply the class-conditional gamma/beta passed in as arguments
        if weight.dim() == 1:
            weight = weight.unsqueeze(0)
        if bias.dim() == 1:
            bias = bias.unsqueeze(0)
        size = output.size()
        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
        return weight * output + bias
```
The key part of this code is the last few lines; everything before that is standard batch norm: compute the mean and variance, then normalize.
```python
if weight.dim() == 1:
    weight = weight.unsqueeze(0)
if bias.dim() == 1:
    bias = bias.unsqueeze(0)
size = output.size()
weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
```
weight is an argument passed into forward. It is a vector determined by the class label, with the same length as the number of feature-map channels. Because this vector corresponds one-to-one with the conditional label y, scaling output channel-wise by weight fuses the conditional information into the features. Note that weight has to go through a series of reshape operations before it can be multiplied with output. Also, Categorical Conditional BatchNorm is used only in the generator.
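To make the reshaping concrete, here is a minimal standalone sketch (all sizes are illustrative) of how a per-sample, per-channel (B, C) weight broadcasts over a (B, C, H, W) feature map:

```python
import torch

B, C, H, W = 4, 64, 8, 8           # illustrative sizes
output = torch.randn(B, C, H, W)   # normalized feature maps from batch norm
weight = torch.ones(B, C)          # one gamma per sample, per channel
bias = torch.zeros(B, C)           # one beta per sample, per channel

# (B, C) -> (B, C, 1, 1) -> broadcast to (B, C, H, W)
size = output.size()
weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)

modulated = weight * output + bias  # same affine step as the last line above
print(modulated.shape)              # torch.Size([4, 64, 8, 8])
```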
```python
from torch.nn import init


class CategoricalConditionalBatchNorm2d(ConditionalBatchNorm2d):

    def __init__(self, num_classes, num_features, eps=1e-5, momentum=0.1,
                 affine=False, track_running_stats=True):
        super(CategoricalConditionalBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        # one gamma/beta vector per class, stored as embedding rows
        self.weights = nn.Embedding(num_classes, num_features)
        self.biases = nn.Embedding(num_classes, num_features)
        self._initialize()

    def _initialize(self):
        init.ones_(self.weights.weight.data)
        init.zeros_(self.biases.weight.data)

    def forward(self, input, c, **kwargs):
        # look up the gamma/beta rows for the class labels c
        weight = self.weights(c)
        bias = self.biases(c)
        return super(CategoricalConditionalBatchNorm2d, self).forward(
            input, weight, bias)
```
From the definitions of weights and biases we can see that weights is in fact a matrix (an nn.Embedding): each row is the embedding vector of one class, and that vector modulates the batch norm of the samples in the batch belonging to that class.
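As a quick usage sketch (the class count and sizes here are made up for illustration), each sample in a batch is normalized and then scaled/shifted by the embedding row of its own label:

```python
import torch

# hypothetical sizes: 10 classes, 64-channel feature maps
cbn = CategoricalConditionalBatchNorm2d(num_classes=10, num_features=64)

x = torch.randn(8, 64, 16, 16)   # a batch of feature maps
c = torch.randint(0, 10, (8,))   # one class label per sample

out = cbn(x, c)                  # each sample is scaled/shifted by its
print(out.shape)                 # class's embedding row -> (8, 64, 16, 16)
```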
Projection Discriminator
The figure above shows the GAN architecture given in the paper. Notice that, unlike the earlier cGAN, the discriminator passes y through an embedding step. This is quite similar to how Categorical Conditional BatchNorm fuses y in the first part, except that Categorical Conditional BatchNorm operates on feature maps, while projection takes an inner product between the discriminator's globally pooled features and the embedding vector of y.
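Before reading the code, it helps to state what projection computes. In the paper's notation the discriminator output takes the form

$$
f(\boldsymbol{x}, y) = \boldsymbol{v}_y^{\top} \boldsymbol{\phi}(\boldsymbol{x}) + \psi\big(\boldsymbol{\phi}(\boldsymbol{x})\big)
$$

where $\boldsymbol{\phi}(\boldsymbol{x})$ is the globally pooled feature (h in the code below), $\psi$ is a linear layer producing the unconditional score (l6), and $\boldsymbol{v}_y$ is the embedding vector of class y (l_y).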
```python
def forward(self, x, y=None):
    h = x
    h = self.block1(h)
    h = self.block2(h)
    h = self.block3(h)
    h = self.block4(h)
    h = self.block5(h)
    h = self.activation(h)
    # Global pooling
    h = torch.sum(h, dim=(2, 3))
    output = self.l6(h)  # unconditional score psi(h)
    if y is not None:
        # projection term: inner product between h and the embedding of y
        output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
    return output
```
We can find the use of y in the discriminator's forward pass:
```python
if y is not None:
    output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
```
And self.l_y is defined as follows:
```python
if num_classes > 0:
    self.l_y = utils.spectral_norm(
        nn.Embedding(num_classes, num_features * 16))
```
In other words, the embedding vector for y is looked up from an nn.Embedding whose weight matrix is constrained by spectral_norm (a normalization that keeps training stable), and that vector is then used in the dot product with h.
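Putting the projection pieces together, here is a minimal standalone sketch (sizes are illustrative, and torch.nn.utils.spectral_norm stands in for the repo's utils.spectral_norm):

```python
import torch
import torch.nn as nn

num_classes, feat_dim = 10, 1024   # illustrative sizes
l6 = nn.Linear(feat_dim, 1)        # unconditional score psi(h)
l_y = nn.utils.spectral_norm(nn.Embedding(num_classes, feat_dim))  # v_y

h = torch.randn(8, feat_dim)              # globally pooled features phi(x)
y = torch.randint(0, num_classes, (8,))   # class labels

output = l6(h)                                                # psi(h)
output = output + torch.sum(l_y(y) * h, dim=1, keepdim=True)  # + v_y . h
print(output.shape)                       # torch.Size([8, 1])
```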