当前位置: 代码迷 >> 综合 >> PyTorch提取中间层的特征(Resnet)
  详细解决方案

PyTorch提取中间层的特征(Resnet)

热度:60   发布时间:2023-10-11 06:56:44.0

     特征提取在深度学习的训练中是经常要做的事情,之前的一篇blog有写到使用pytorch提取Vgg、Resnet、Densenet三种模型下的特征,这里所述的是提取全连接层(FC层)的特征,详情可见:https://blog.csdn.net/qq_34611579/article/details/84330968。

     在本文中,主要是介绍提取中间层的特征,对于特征的提取,可以先把模型的结构输出,不同的模型结构是不一样的;下面拿resnet作为示例;由于pytorch模型很多用到nn.sequential,所以对各层的特征提取要自己去修改forward函数。

# 中间层特征提取
class FeatureExtractor(nn.Module):def __init__(self, submodule, extracted_layers):super(FeatureExtractor, self).__init__()self.submodule = submoduleself.extracted_layers = extracted_layers# 自己修改forward函数def forward(self, x):outputs = []for name, module in self.submodule._modules.items():if name is "fc": x = x.view(x.size(0), -1)x = module(x)if name in self.extracted_layers:outputs.append(x)return outputs

    这里可以看到,我们自定义了forward函数,由此可以选择在哪一层提取特征。

extract_list = ["conv1", "maxpool", "layer1", "avgpool", "fc"]
img_path = "./1_00001.jpg"
saved_path = "./1_00001.txt"
resnet = models.resnet50(pretrained=True)
# print(resnet) 可以打印看模型结构transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor()]
)img = Image.open(img_path)
img = transform(img)x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)if use_gpu:x = x.cuda()resnet = resnet.cuda()extract_result = FeatureExtractor(resnet, extract_list)
print(extract_result(x)[4])  # [0]:conv1  [1]:maxpool  [2]:layer1  [3]:avgpool  [4]:fc

  对于模型的结构我们可以打印出来看看,这里是采用的Resnet,此时forward函数也是针对该模型进行了修改;若比如其他的VGG,DenseNet也可以用类似的方法进行修改。

参考这里所说:https://blog.csdn.net/qq_24306353/article/details/82995320.

还有一些比如可视化,可参考:https://blog.csdn.net/xz1308579340/article/details/85622579.

源代码:https://github.com/Messi-Q/Pytorch-extract-feature/blob/master/feature_extract.py.