
PyTorch in Practice (7) - Machine Translation (2): Seq2Seq + Attention



  • 1. Task Overview
  • 2. Algorithm Flow
  • 3. Code Implementation and Analysis

1. Task Overview

In the previous Seq2Seq post we showed how to implement machine translation with the encoder-decoder framework. In this post we add an attention mechanism on top of it.

2. Algorithm Flow

  • Encoder(x, x_len): returns encoder_output, encoder_hid (a minimal Encoder sketch is given after this list)
  • Context = encoder_output, context_len = x_len
  • Decoder(Context, context_len, y, y_len, encoder_hid):
  • Intermediate result from running the decoder GRU over (y, y_len) with initial state encoder_hid: decoder_output, decoder_hid
  • Attention(decoder_output, Context, context_len): returns out, attn
  • Output: out → linear → tanh
  • Output = F.log_softmax(Output)
  • Final result: output
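The Encoder itself was defined in the previous post and is not repeated here. For reference, the following is a minimal sketch of an Encoder that matches the interface assumed by the Decoder and Attention code below: a bidirectional GRU whose outputs have width 2*enc_hidden_size, and whose final states of both directions are concatenated and projected to dec_hidden_size to serve as the decoder's initial hidden state. Treat it as an illustrative assumption, not the original post's code.

import torch
import torch.nn as nn


class Encoder(nn.Module):
    # Hypothetical sketch: shapes chosen to match the Decoder/Attention below.
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, enc_hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(enc_hidden_size * 2, dec_hidden_size)

    def forward(self, x, lengths):
        # sort by length so the batch can be packed
        sorted_len, sorted_idx = lengths.sort(0, descending=True)
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))

        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # restore the original batch order
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()     # batch_size, seq_len, 2*enc_hidden_size
        hid = hid[:, original_idx.long()].contiguous()  # 2, batch_size, enc_hidden_size

        # combine the two directions and project to the decoder's hidden size
        hid = torch.cat([hid[-2], hid[-1]], dim=1)
        hid = torch.tanh(self.fc(hid)).unsqueeze(0)     # 1, batch_size, dec_hidden_size
        return out, hid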

3. Code Implementation and Analysis

  • Attention
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, enc_hidden_size, dec_hidden_size):
        super(Attention, self).__init__()
        self.enc_hidden_size = enc_hidden_size
        self.dec_hidden_size = dec_hidden_size
        self.linear_in = nn.Linear(enc_hidden_size * 2, dec_hidden_size, bias=False)
        self.linear_out = nn.Linear(enc_hidden_size * 2 + dec_hidden_size, dec_hidden_size)

    def forward(self, output, context, mask):
        # output:  batch_size, output_len, dec_hidden_size   (decoder hidden states)
        # context: batch_size, context_len, 2*enc_hidden_size (encoder outputs)
        batch_size = output.size(0)
        output_len = output.size(1)
        input_len = context.size(1)

        # project the encoder outputs into the decoder's hidden space
        context_in = self.linear_in(context.view(batch_size * input_len, -1)).view(
            batch_size, input_len, -1)  # batch_size, context_len, dec_hidden_size

        # attention scores: dot product between decoder states and projected encoder outputs
        # context_in.transpose(1, 2): batch_size, dec_hidden_size, context_len
        attn = torch.bmm(output, context_in.transpose(1, 2))  # batch_size, output_len, context_len

        # mask out padded positions, then normalize over the context dimension
        attn = attn.masked_fill(mask, -1e6)
        attn = F.softmax(attn, dim=2)  # batch_size, output_len, context_len

        # weighted sum of the encoder outputs
        context = torch.bmm(attn, context)  # batch_size, output_len, 2*enc_hidden_size
        output = torch.cat((context, output), dim=2)  # batch_size, output_len, 2*enc_hidden_size + dec_hidden_size

        output = output.view(batch_size * output_len, -1)
        output = torch.tanh(self.linear_out(output))
        output = output.view(batch_size, output_len, -1)
        return output, attn
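As a quick sanity check of the shapes (all sizes below are made up for illustration), the module can be exercised with random tensors:

attention = Attention(enc_hidden_size=100, dec_hidden_size=200)
decoder_states = torch.randn(4, 7, 200)         # batch_size=4, output_len=7, dec_hidden_size
encoder_outputs = torch.randn(4, 9, 200)        # context_len=9, 2*enc_hidden_size
mask = torch.zeros(4, 7, 9, dtype=torch.bool)   # nothing masked in this toy example
out, attn = attention(decoder_states, encoder_outputs, mask)
print(out.shape)   # torch.Size([4, 7, 200])
print(attn.shape)  # torch.Size([4, 7, 9])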
  • Decoder
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(enc_hidden_size, dec_hidden_size)
        self.rnn = nn.GRU(embed_size, dec_hidden_size, batch_first=True)
        self.out = nn.Linear(dec_hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, ctx, ctx_lengths, y, y_lengths, hid):
        # sort the target batch by length so it can be packed
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))  # batch_size, output_len, embed_size

        packed_seq = nn.utils.rnn.pack_padded_sequence(
            y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        # restore the original batch order
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        # attend over the encoder outputs, then project to the vocabulary
        mask = self.create_mask(y_lengths, ctx_lengths)
        output, attn = self.attention(output_seq, ctx, mask)
        output = F.log_softmax(self.out(output), -1)
        return output, hid, attn
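The Decoder calls self.create_mask, which the post does not show. A minimal sketch consistent with how Attention consumes the mask (True marks context positions that should be suppressed, i.e. padding beyond each sentence's real length) could look like the method below; it is an assumed implementation, not the original code, and should be added inside the Decoder class.

    def create_mask(self, x_len, y_len):
        # Hypothetical helper (not shown in the post): builds a
        # batch_size x output_len x context_len boolean mask that is True
        # wherever attention should be blocked.
        device = x_len.device
        max_x_len = x_len.max()
        max_y_len = y_len.max()
        x_mask = torch.arange(max_x_len, device=device)[None, :] < x_len[:, None]  # valid target steps
        y_mask = torch.arange(max_y_len, device=device)[None, :] < y_len[:, None]  # valid context steps
        mask = ~(x_mask[:, :, None] & y_mask[:, None, :])
        return mask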
  • Seq2Seq
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid, attn = self.decoder(ctx=encoder_out, ctx_lengths=x_lengths,
                                         y=y, y_lengths=y_lengths, hid=hid)
        return output, attn
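To tie the pieces together, a toy forward pass might look like the following. All sizes are hypothetical, and it uses the Encoder sketch from section 2, so it is only a shape-level illustration of how the modules compose.

vocab_size, embed_size = 1000, 100
enc_hidden_size, dec_hidden_size = 100, 200

encoder = Encoder(vocab_size, embed_size, enc_hidden_size, dec_hidden_size)
decoder = Decoder(vocab_size, embed_size, enc_hidden_size, dec_hidden_size)
model = Seq2Seq(encoder, decoder)

x = torch.randint(0, vocab_size, (4, 9))   # source batch: batch_size=4, max length 9
y = torch.randint(0, vocab_size, (4, 7))   # target batch: max length 7
x_lengths = torch.tensor([9, 8, 6, 5])
y_lengths = torch.tensor([7, 7, 5, 4])

output, attn = model(x, x_lengths, y, y_lengths)
print(output.shape)  # torch.Size([4, 7, 1000]) -- log-probabilities over the vocabulary
print(attn.shape)    # torch.Size([4, 7, 9])    -- attention weights over the source positions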