当前位置: 代码迷 >> 综合 >> 6d pose estimation 之 PVnet
  详细解决方案

6d pose estimation 之 PVnet

热度:94   发布时间:2024-01-05 06:24:47.0

pvnet 源github:https://github.com/zju3dv/clean-pvnet

最近在做6D pose estimation时用了pvnet的算法,精度非常不错。想着看看能不能在pvnet上加上目标检测功能,同时判断问题是否存在,原本想着用pvnet的mask做个阈值来判断是否存在物体,结果发现阈值99%了还是无法区分出误检的区域(可能是我背景多样性做的不够)。

整体效果还可以,主要思路还是在resnet18 网络中提最后3层进行目标检测,损失采用yolov4思路。

        self.layer5_bbox = nn.Sequential(nn.Conv2d(512, 512, 3, 2, padding=1),nn.BatchNorm2d(512),nn.LeakyReLU(0.1, True))self.layer5_bbox_fc = nn.Conv2d(1024, 3 * (4 + 1 + 1), 3, 1)self.layer6_bbox = nn.Sequential(nn.Conv2d(512, 512, 3, 2, padding=2),nn.BatchNorm2d(512),nn.LeakyReLU(0.1, True))self.layer6_bbox_up = nn.UpsamplingBilinear2d(size=(38, 38))self.layer6_bbox_fc = nn.Conv2d(1536, 3 * (4 + 1 + 1), 3, 1)self.layer7_bbox = nn.Sequential(nn.Conv2d(512, 1024, 3, 2, padding=2),nn.BatchNorm2d(1024),nn.LeakyReLU(0.1, True))self.layer7_bbox_up = nn.UpsamplingBilinear2d(size=(20, 20))self.layer7_bbox_fc = nn.Conv2d(1024, 3 * (4 + 1 + 1), 3, 1)def forward(self, x, feature_alignment=False):x2s, x4s, x8s, x16s, x32s, xfc = self.resnet18_8s(x)bbox_1 = self.layer5_bbox(x32s)bbox_2 = self.layer6_bbox(bbox_1)bbox_2_up = self.layer6_bbox_up(bbox_2)bbox_3 = self.layer7_bbox(bbox_2)bbox_3_up = self.layer7_bbox_up(bbox_3)bbox_1_fc = self.layer5_bbox_fc(torch.cat([bbox_1, bbox_2_up], 1))bbox_2_fc = self.layer6_bbox_fc(torch.cat([bbox_2, bbox_3_up], 1))bbox_3_fc = self.layer7_bbox_fc(bbox_3)# x2s, x4s, x8s, _, _, xfcfm = self.conv8s(torch.cat([xfc, x8s], 1))  # 结合resnet18的 x8s(第三层) xfc(最后一层)然后用于上采样fm = self.up8sto4s(fm)  # 第一层 上采样一倍# if fm.shape[2] == 136:#     fm = nn.functional.interpolate(fm, (135, 180), mode='bilinear', align_corners=False)fm = self.conv4s(torch.cat([fm, x4s], 1))fm = self.up4sto2s(fm)  # 第二层 上采样一倍fm = self.conv2s(torch.cat([fm, x2s], 1))fm = self.up2storaw(fm)  # 第三层 上采样一倍x = self.convraw(torch.cat([fm, x], 1))seg_pred = x[:, :2, :, :]  # 前2个维度是分割# if self.seg_dim > 2:#     bounding_box = x[:, 2:self.seg_dim, :, :]  # 前2个维度是分割# else:#     bounding_box = Nonever_pred = x[:, self.seg_dim:, :, :]  # 后18个维度是 8个关键点+1个中心点, vector-field# ret = {'seg': seg_pred, 'vertex': ver_pred} # @!!!@ret = {'seg': seg_pred,'bounding_box': [bbox_1_fc, bbox_2_fc, bbox_3_fc],'vertex': ver_pred}if not self.training:with torch.no_grad():self.decode_keypoint(ret)return ret

  相关解决方案