我试图用torch.load()加载.csv文件,想想都觉得自己的操作很蠢,与其说蠢,不如说掌握的知识少。接下来进入正题:
1. torch.load的用法
torch.load()是用来加载torch.save()存储的对象的方法。
load()使用Python的unpickling工具,但是专门处理存储,它是张量的基础。他们首先在CPU上并行化,然后移动到保存它们的设备。如果失败(例如,因为运行时系统没有某些设备),就会引发异常。
所以我用DataFrame.to_csv()保存在.csv文件中的信息不能通过torch.load加载。
2. 解决方法
直接使用python自带的函数open(文件名.csv, ‘r’)来读取.csv文件,给大家展示的我读取.csv文件并获取其中字段的方法。
def load_loss_acc(file):"""从.csv文件中读取loss, acc@1和acc@5:param file: 指定.csv文件名:return: 返回loss, acc@1和acc@5"""lines = [x.strip() for x in open(file, 'r').readlines()][1:]epoch_num = [] # 存储保存的epoch值train_loss = []test_acc_1 = [] # top1准确率test_acc_5 = [] # top5准确率for l in lines:epoch_num, train_loss, test_acc_1, test_acc_5 = l.split(',')return train_loss, test_acc_1, test_acc_5```