当前位置: 代码迷 >> 综合 >> cvpr2019_Pyramid-Feature-Attention-Network-for-Saliency-detection 论文复现
  详细解决方案

cvpr2019_Pyramid-Feature-Attention-Network-for-Saliency-detection 论文复现

热度:59   发布时间:2024-02-28 12:39:31.0

背景介绍

这两天学习attention机制,看到了这篇文章,就看了看文章和他的代码,代码有些地方没有注明,使用起来难度较大,故在这里写一下这篇文章的代码复现。话不多说,开始展示。

论文信息

代码地址:github地址
论文地址:arxiv地址
ECSSD数据集:ECSSD dataset 地址
参考链接:github地址
参考链接:github地址

实现难点

  1. 缺少两个文件(可能是我没看懂作者的代码,比如cvpr2019——PFA里 get_list.py 可能是产生 train_pair.txt 的,但我确实没操作出来;
  2. 部分语句有点问题;
  3. python3和python2 的print语句差别
  4. 数据集不完整
    PS(我解决了这些我认为是问题的问题,可能作者在文档中有相关文件,可能是我没看到,我只是按照我的方法实现了结果复现)

问题1 缺少两个文件

从github上下载代码,点击Download ZIP
点击Download ZIP解压出来
文件目录大致如下
文件目录
这里呢,相对于下载下来的目录是多了两个文件的,换句话说,下载下来的代码中是没有以下两个文件的,用下面的代码新建一个generate_train_file.py文件即可。

  1. generate_train_file.py # 用于生成train_pair.txt
  2. train_pair.txt # 用于读入数据

generate_train_file.py 代码

generate_train_file.py 用于生成指定数据集的 train_pair.txt. 因为train.py通过调用 data.py 里面的getTrainGenerator函数实现对数据的遍历,所以针对不同的数据集,需要生成不同的train_pair.txt

import os
dataset_root = "cvpr2019_PFA/ECSSD"#此处使用的是ECSSD数据集,也可以换成其他的
img_list = []
def check_num_images():jpg_count = 0gt_count = 0for root, dirs, files in os.walk(dataset_root):for fname in files:if 'jpg' in fname:jpg_count+=1img_list.append(fname[:-4])if 'png' in fname:gt_count+=1print ("num of images: {}, num of GT maps: {}".format(jpg_count, gt_count))
check_num_images()
with open("train_pair.txt", 'w+') as fout:for img in img_list:img_path = os.path.join(dataset_root, img)fout.write(f'{img_path}.jpg {img_path}.png\n')

问题2 部分语句

把train.py文件中的这几个名字按下面对应的修改一下,有些对应是错误的,所以会引起错误。
截图来自:github地址

修改1
修改2

问题3 print语句 py3 py2的差别

把print后面的内容加个括号就行
把print后面的内容加个括号就行

问题4 载入数据

以 ECSSD dataset 为例
这里有一个问题就是,train_pair.txt中保存的名字是路径加文件名的形式,而且前面是 .jpg 图片格式的,后面是 .png 格式的;其实这里就对数据集进行了一定的处理,.jpg 是原始图像,.png是mask图像。
train_pair.txt内容如下图
train_pair.txt内容展示
可以从下图看到相关代码的引用:
data.py 中的函数
data.py 中的函数但是查看数据集的时候,发现只有.png图像,没有.jpg图像,所以一直报错 ,错误行就是下面这一行:
错误的内容大致说的是:需要输入的是tuple, but get None。

model.fit_generator(traingen, steps_per_epoch=steps_per_epoch,epochs=epochs,verbose=1,callbacks=callbacks)

原因是因为这里的traingen没有正确读入数据,就是因为数据集不是标准的读入格式。

问题4 解决办法

此处以 ECSSD 为例。
首先,下载ECSSD标准数据集。ECSSD dataset 地址
下载页面,往下滚就能看到这个
下载页面,往下滚就能看到这个
解压了能看到两个文件夹,ground_truth_mask是标准结果,images是原图
下载下来的ECSSD dataset如下图两个文件包
下载下来的ECSSD dataset
关键操作:
1.把images.zip解压
2.把images文件夹中的所有文件复制到

cvpr2019_Pyramid-Feature-Attention-Network-for-Saliency-detection-master/cvpr2019_PFA/ECSSD

这个ECSSD文件夹下,这样,ECSSD文件夹下就同时有同名的 .jpg 和 .png图片了,这样就能顺利读入数据了。
ECSSD 文件夹

结尾

安装好相应的库之后,就可以开始啦!其他数据集的话,应该应用此方法也能解决。
开始 RUN 了

biubiubiu,跑起来了,这是最开心的时候!

  相关解决方案