当前位置: 代码迷 >> 综合 >> 损失函数SSIM (structural similarity index) 的PyTorch实现
  详细解决方案

损失函数SSIM (structural similarity index) 的PyTorch实现

热度:83   发布时间:2023-12-17 07:14:55.0

参考这篇文章,结合我这里的代码补充一下文章链接

def gaussian(window_size, sigma):
###window_size = 11gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])print('gauss.size():', gauss.size())### torch.Size([11])return gauss/gauss.sum()def create_window(window_size, channel):_1D_window = gaussian(window_size, 1.5).unsqueeze(1)print('_1D_window.size():', _1D_window.size())### torch.Size([11, 1])###unsqueeze(1)在维度1 增加一个维度_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)print(_1D_window.mm(_1D_window.t()).size())### torch.Size([11, 11])print('_2D_window.size()', _2D_window.size())###在得到(11,11)tensor之后,接连在第0维增加维度 ###_2D_window.size() torch.Size([1, 1, 11, 11])window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())###expand the _2D_window to all the channels(if channel == 3 the other 2 equal to the first one) ###channel 有时候可以是一个通道,但是有时间可以是2 个通道,单个通道,我们那一般颜色通道RGB,三通道,这样子的话,要把前面的那一个通道扩展到3通道###如果是只用了Y channel的话,那就只有一个通道。print('window.size()', window.size())return window

插播一下关于
torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
有问题的时候去找官方文档
官方介绍

  • 简单的一个小实验
>>> inputs = torch.randn(1,4,5,5)
>>> filters = torch.randn(8,4,3,3)
>>> F.conv2d(inputs, filters, padding=1).size()
torch.Size([1, 8, 5, 5])

input – input tensor of shape (minibatch,in_channels,iH,iW)
注意input是这样(minibatch, in_channels, H, W)
第0维是batch_size的数

torch.mean()
官方文档介绍

  • 继续SSIM相关内容
def _ssim(img1, img2, window, window_size, channel, size_average = True):mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)mu1_sq = mu1.pow(2)mu2_sq = mu2.pow(2)mu1_mu2 = mu1*mu2sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq#这里利用 Var(x) = E(x^2) - E(x)^2sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sqsigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2#这里利用E(X,Y)= E(XY)- E(X)E(Y)print(sigma1_sq.size(), sigma12.size())C1 = 0.01**2C2 = 0.03**2ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))print(ssim_map.size())if size_average:print(ssim_map.mean())### ssim_map.mean()是对这个tensor里面的所有的数值求平均return 1 - ssim_map.mean()else:print(ssim_map.mean(1).mean(1).mean(1))return 1 - ssim_map.mean(1).mean(1).mean(1)class SSIM(torch.nn.Module):def __init__(self, window_size = 11, size_average = True):super(SSIM, self).__init__()self.window_size = window_sizeself.size_average = size_averageself.channel = 1self.window = create_window(window_size, self.channel)def forward(self, img1, img2):(_, channel, _, _) = img1.size()if channel == self.channel and self.window.data.type() == img1.data.type():window = self.windowelse:window = create_window(self.window_size, channel)if img1.is_cuda:window = window.cuda(img1.get_device())window = window.type_as(img1)self.window = windowself.channel = channelreturn _ssim(img1, img2, window, self.window_size, channel, self.size_average)

到这里这个求解损失函数的函数里面的每一个点都清楚啦~

  相关解决方案