当前位置: 代码迷 >> 综合 >> torch.dataset的构建
  详细解决方案

PyTorch Dataset（torch.utils.data.Dataset）的构建

热度:46   发布时间:2023-11-13 19:33:20.0

1.数据转换类(TranslateData)

import random
import re
import sys
import unicodedata

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

# Add the current working directory to the import path so the local
# vocabField module can be imported.
sys.path.append('.')
from vocabField import VocabField
# translate_data is used by DialogDataset as its per-line transform;
# collate_fn is used by DataLoader to merge items into a batch.
class TranslateData():
    """Turn sentence pairs into id tensors and batch them for a seq2seq model.

    Args:
        pad: id written into the padded positions of both ``src`` and ``tgt``
            by ``collate_fn`` (a single value shared by both sides).
        split_punctuation: if True, ``.``, ``!`` and ``?`` become separate
            tokens ("go." -> "go", "."). Defaults to False, which keeps
            punctuation attached to the preceding word as the original
            normalisation did. The vocab files must have been built with
            the same setting, or ``word2idx`` lookups raise KeyError.
    """

    # Compiled once at class definition; used by _normalize_string.
    _SENTENCE_PUNCT_RE = re.compile(r'([.!?])')
    _NON_LETTER_RE = re.compile(r'[^a-zA-Z.!?]+')

    def __init__(self, pad=0, *, split_punctuation=False):
        self.pad = pad
        self.split_punctuation = split_punctuation

    def collate_fn(self, batch):
        """Merge a list of items produced by ``translate_data`` into one batch.

        Args:
            batch: list of dicts with keys 'src', 'tgt' (1-D id tensors) and
                'src_len', 'tgt_len' (1-element length tensors).

        Returns:
            dict with 'src' and 'tgt' of shape (batch, longest_sequence),
            right-padded with ``self.pad``, and 'src_len' / 'tgt_len' of
            shape (batch, 1), made by stacking the per-item length tensors.
        """
        src = [item['src'] for item in batch]
        tgt = [item['tgt'] for item in batch]
        src_len = [item['src_len'] for item in batch]
        tgt_len = [item['tgt_len'] for item in batch]
        return {
            # batch_first=True yields (batch, time) directly, equivalent to
            # transposing the default (time, batch) layout.
            'src': pad_sequence(src, batch_first=True, padding_value=self.pad),
            'tgt': pad_sequence(tgt, batch_first=True, padding_value=self.pad),
            'src_len': torch.stack(src_len),
            'tgt_len': torch.stack(tgt_len),
        }

    @staticmethod
    def _unicode_to_ascii(s):
        """Drop combining marks after NFD decomposition ("café" -> "cafe")."""
        return ''.join(
            c for c in unicodedata.normalize('NFD', s)
            if unicodedata.category(c) != 'Mn'
        )

    def _normalize_string(self, s):
        """Lower-case, ASCII-fold and replace runs of anything but letters and .!? with a space."""
        s = self._unicode_to_ascii(s.lower().strip())
        if self.split_punctuation:
            s = self._SENTENCE_PUNCT_RE.sub(r' \1', s)
        return self._NON_LETTER_RE.sub(' ', s)

    def translate_data(self, subs, obj):
        """Convert one (source, target) sentence pair into id tensors.

        Args:
            subs: sequence of exactly two strings: the tab-separated columns
                of one data line (source sentence, target sentence).
            obj: object exposing ``src_vocab`` / ``tgt_vocab`` (each with a
                ``word2idx`` mapping; the target vocab also has
                ``sos_token`` / ``eos_token``) and ``max_src_length`` /
                ``max_tgt_length``. In this file, the DialogDataset instance.

        Returns:
            None if either side is longer than its limit (the target length
            includes the added SOS/EOS tokens). Otherwise a dict of long
            tensors: 'src' / 'tgt' (1-D token ids) and 'src_len' / 'tgt_len'
            (1-element lengths).

        Raises:
            KeyError: if a token is missing from the corresponding
                ``word2idx``; there is no unknown-word fallback here.
        """
        src, tgt = subs
        # str.split() with no argument drops empty tokens that split(' ') would
        # produce when the normalised text starts or ends with a space
        # (e.g. "hello 42" -> "hello ").
        src_tokens = self._normalize_string(src).split()
        tgt_tokens = [
            obj.tgt_vocab.sos_token,
            *self._normalize_string(tgt).split(),
            obj.tgt_vocab.eos_token,
        ]
        if len(src_tokens) > obj.max_src_length or len(tgt_tokens) > obj.max_tgt_length:
            return None
        src_ids = [obj.src_vocab.word2idx[w] for w in src_tokens]
        tgt_ids = [obj.tgt_vocab.word2idx[w] for w in tgt_tokens]
        return {
            'src': torch.tensor(src_ids, dtype=torch.long),
            'tgt': torch.tensor(tgt_ids, dtype=torch.long),
            'src_len': torch.tensor([len(src_tokens)], dtype=torch.long),
            'tgt_len': torch.tensor([len(tgt_tokens)], dtype=torch.long),
        }

2.Dataset类(DialogDataset)

class DialogDataset(Dataset):
    """Map-style dataset that keeps every accepted example in memory.

    Each non-blank line of ``data_fp`` is stripped, split on tabs and passed
    to ``transform_fuc(subs, self)``. The transform can read the vocabularies
    and length limits stored on this instance. Falsy results (e.g. ``None``
    for rejected pairs) are dropped.

    Args:
        data_fp: path to a UTF-8 text file with one tab-separated example per line.
        transform_fuc: callable ``(subs, dataset) -> item``; a falsy return
            rejects the line.
        src_vocab, tgt_vocab: vocabularies exposed to ``transform_fuc``.
        max_src_length, max_tgt_length: length limits exposed to ``transform_fuc``.

    Raises:
        ValueError: if a line's column count differs from that of the first
            non-blank line.
    """

    def __init__(self, data_fp, transform_fuc, src_vocab, tgt_vocab, max_src_length, max_tgt_length):
        self.datasets = []
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        # Explicit UTF-8: the locale default (e.g. GBK or cp1252) may fail to
        # decode accented text.
        with open(data_fp, 'r', encoding='utf-8') as f:
            loaded = self._load_lines(tqdm(f, desc='Load Data:'), transform_fuc)
        # Guard against an empty file, which would otherwise divide by zero.
        rate = len(self.datasets) / loaded if loaded else 0.0
        print(f"{loaded} pairs loaded. {len(self.datasets)} are valid. Rate {rate:.4f}")

    def _load_lines(self, lines, transform_fuc):
        """Parse ``lines`` into ``self.datasets``; return how many non-blank lines were read."""
        loaded = 0
        n_columns = None  # fixed by the first non-blank line
        for line_no, line in enumerate(lines, start=1):
            line = line.strip()
            if not line:
                continue
            subs = line.split('\t')
            loaded += 1
            if n_columns is None:
                n_columns = len(subs)
            elif len(subs) != n_columns:
                # Raised explicitly: an assert would be stripped under `python -O`.
                raise ValueError(
                    f"line {line_no} has {len(subs)} tab-separated columns, expected {n_columns}"
                )
            item = transform_fuc(subs, self)
            if item:
                self.datasets.append(item)
        return loaded

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, idx):
        return self.datasets[idx]

3.测试

# Smoke test: build the train/dev datasets and DataLoaders from the fra->eng files.
# NOTE(review): train_path points at the *dev* file, the same as dev_path.
# Presumably a placeholder for this quick test; confirm before real training.
train_path = '../../data/fra2eng/fra_eng.dev'
dev_path = '../../data/fra2eng/fra_eng.dev'
src_vocab_file = '../../data/fra2eng/src_vocab_file'
tgt_vocab_file = '../../data/fra2eng/tgt_vocab_file'
src_vocab_size = 40000
tgt_vocab_size = 40000
max_src_length = 50
max_tgt_length = 50
batch_size = 20

src_vocab_list = VocabField.load_vocab(src_vocab_file)
tgt_vocab_list = VocabField.load_vocab(tgt_vocab_file)
src_vocab = VocabField(src_vocab_list, vocab_size=src_vocab_size)
tgt_vocab = VocabField(tgt_vocab_list, vocab_size=tgt_vocab_size)
pad_id = tgt_vocab.word2idx[tgt_vocab.pad_token]

# Pad batches with the vocabulary's real pad id (previously computed but
# unused, so padding silently used the default 0).
# NOTE(review): the same id pads src as well; assumes both vocabs share it.
trans_data = TranslateData(pad=pad_id)

train_set = DialogDataset(
    train_path,
    trans_data.translate_data,
    src_vocab,
    tgt_vocab,
    max_src_length=max_src_length,
    max_tgt_length=max_tgt_length,
)
trainloader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=trans_data.collate_fn,
)

dev_set = DialogDataset(
    dev_path,
    trans_data.translate_data,
    src_vocab,
    tgt_vocab,
    max_src_length=max_src_length,
    max_tgt_length=max_tgt_length,
)
dev_loader = DataLoader(
    dev_set,
    batch_size=15,
    shuffle=False,
    collate_fn=trans_data.collate_fn,
)
  相关解决方案