在Libtorch中像numpy一样过滤矩阵中的值
分类
这个简单,自己随意处理即可。
语义分割
语义分割中,我们对模型输出结果会需要简单进行阈值过滤,在numpy中,我们会这样做:
output[output < 0.5] = 0
这样我们就过滤了我们不需要的值,但是,我们怎么在libtorch中做呢?libtorch的api看了不少,就是没发现什么好方法。后来在pytorch论坛提问,终于找到了灵感,下面就是解决方案:
假设一个 {1, 10, 224, 224} 的 output,我们会按照如下流程来做
auto mask = torch::softmax(output, 1);
mask.index_put_({mask < 0.5}, 0);
这样就过滤掉了你 <0.5 的值了。
官方文档太不健全,看API真的很辛苦。