torch.masked_select for mask-aware image segmentation on the DRIVE test set

Published: 2023-11-21 02:42:56

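In datasets like DRIVE, the input image has a non-rectangular boundary. The area outside that boundary is of no interest for segmentation, so the dataset ships a binary mask that is used to screen it out.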

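After the forward pass the network produces output. During training it is compared with the ground truth to compute the loss, and during testing it is used to compute accuracy metrics. Both clearly have to ignore the black region of the mask. At test time applying the mask is fairly easy, so the real question is: how do we ignore the loss in the black mask region during training?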
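One way this could look in training is sketched below. It is only a minimal sketch under my own assumptions: the loss is binary cross-entropy on logits, the tensors are toy stand-ins rather than real DRIVE data, and the names masked_bce_loss and fov_mask are hypothetical. The idea is simply to select the pixels inside the mask from both output and ground truth before computing the loss.

```python
import torch
import torch.nn.functional as F

def masked_bce_loss(logits, target, fov_mask):
    """Binary cross-entropy averaged only over pixels where fov_mask is nonzero.

    Hypothetical helper: the choice of BCE is an assumption, not from the post.
    """
    keep = fov_mask.bool()  # assumption: a recent PyTorch that expects a bool mask
    kept_logits = torch.masked_select(logits, keep)  # 1-D tensor of pixels inside the mask
    kept_target = torch.masked_select(target, keep)
    return F.binary_cross_entropy_with_logits(kept_logits, kept_target)

# Toy batch: 2 images, 1 channel, 4x4 pixels (stand-ins for real DRIVE tensors).
torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4, requires_grad=True)
target = torch.randint(0, 2, (2, 1, 4, 4)).float()
fov_mask = torch.zeros(2, 1, 4, 4)
fov_mask[:, :, 1:3, 1:3] = 1  # only the central 2x2 block counts

loss = masked_bce_loss(logits, target, fov_mask)
loss.backward()  # pixels outside the mask receive zero gradient
```

Because the masked-out pixels never enter the loss, they contribute nothing to the gradient, which is the behavior we want for the black border.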


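The tool for this is torch.masked_select. Unfortunately the official PyTorch documentation site is blocked here, so I will first copy the original text below.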

torch.masked_select(input, mask, out=None) → Tensor

Returns a new 1-D tensor which indexes the input tensor according to the binary mask mask which is a ByteTensor.

The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.


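In the version of the documentation quoted here, mask is a tensor of type ByteTensor, which you can get with mask = mask.byte(), and it may only contain the values 0 and 1. As far as I know, more recent PyTorch releases expect a torch.bool mask instead, which you can get with mask = mask.bool().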

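The case where mask has the same shape as input is easy to understand: the call returns a 1-D tensor containing the elements of input at the positions where mask is 1.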
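A minimal sketch of the same-shape case, using small made-up values and assuming a recent PyTorch (hence the conversion to a bool mask):

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
mask = torch.tensor([[1, 0, 1],
                     [0, 1, 0]]).bool()  # same shape as x

selected = torch.masked_select(x, mask)
print(selected)        # tensor([1., 3., 5.])
print(selected.shape)  # torch.Size([3])
```

The result is always flattened to one dimension, with the selected elements in row-major order.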

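When the shape of mask differs from that of input, for example a different number of dimensions or different sizes along some dimension, there clearly has to be a rule under which mask can be expanded unambiguously to the shape of input. That is what the word "broadcastable" in the text refers to. The original description follows.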

Two tensors are “broadcastable” if the following rules hold:

  • Each tensor has at least one dimension.
  • When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension

# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

In my own words, broadcasting works in two steps:
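  1. The tensor with fewer dimensions gets size-1 dimensions inserted at the front of its shape (note: at the front) until both tensors have the same number of dimensions.
  2. The sizes are then compared dimension by dimension. Where they differ, one side must have size 1 in that dimension, and it is replicated along that dimension until the sizes match. Otherwise broadcasting fails.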
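The two steps above can be sketched on a small example. This is an illustration with made-up shapes, assuming a PyTorch version that provides torch.broadcast_shapes and expects a bool mask. A 2-D mask is applied to a 3-D input:

```python
import torch

x = torch.arange(24.0).reshape(2, 3, 4)     # e.g. (channels, height, width)
mask = torch.tensor([[1, 0, 0, 1],
                     [0, 0, 0, 0],
                     [0, 1, 0, 0]]).bool()  # shape (3, 4)

# Step 1: mask (3, 4) is treated as (1, 3, 4).
# Step 2: the size-1 leading dimension is replicated to size 2.
common_shape = torch.broadcast_shapes(x.shape, mask.shape)
print(common_shape)  # torch.Size([2, 3, 4])

selected = torch.masked_select(x, mask)
# 3 nonzero mask entries, applied to each of the 2 slices of x
print(selected)  # tensor([ 0.,  3.,  9., 12., 15., 21.])
```

So a single (height, width) mask can be reused for every channel without expanding it by hand.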

As long as input and mask are broadcastable, no error is raised. One small detail:

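When mask is the side with more dimensions, the extra dimensions do not cause an error. In that case it is input that gets broadcast up to the shape of mask, so the result has one element for every nonzero entry of mask.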
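To see what happens when mask has more dimensions than input, here is a small sketch with made-up values, again assuming a recent PyTorch with a bool mask:

```python
import torch

x = torch.tensor([10.0, 20.0, 30.0])     # shape (3,)
mask = torch.tensor([[1, 0, 1],
                     [0, 1, 1]]).bool()  # shape (2, 3)

selected = torch.masked_select(x, mask)
# x is broadcast to (2, 3): row 0 keeps 10 and 30, row 1 keeps 20 and 30
print(selected)  # tensor([10., 30., 20., 30.])
```

Note that the output has four elements, one per nonzero mask entry, even though x itself only has three.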
