import math
import copy
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.init import xavier_normal_, xavier_uniform_
from multi_head_scaled_dot_product_attention import multi_head_scaled_dot_product_attention

class PositionEmbedding(nn.Module):
    """Fixed sinusoidal position encoding:
    PE[pos, 2i] = sin(pos / 10000^(2i / d_model)), PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))."""

    def __init__(self, max_len, d_model):
        super(PositionEmbedding, self).__init__()
        self.max_len = max_len
        self.d_model = d_model
        # Precompute the (max_len, d_model) encoding table once.
        self.pe = torch.zeros(self.max_len, self.d_model, dtype=torch.float32)
        pos = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model))
        self.pe[:, 0::2] = torch.sin(pos * div_term)
        self.pe[:, 1::2] = torch.cos(pos * div_term)

    def forward(self, x):
        b, l, d = x.shape
        assert d == self.d_model
        assert l <= self.max_len
        # Add the first l position vectors to every sequence in the batch; the table is
        # constant, so it is detached from the autograd graph.
        return x + self.pe[:l, :].to(x.device).expand_as(x).clone().detach()
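
# Illustrative sanity check (not part of the original module; the shapes used here are
# assumptions): the encoding table has shape (max_len, d_model) and the same row is
# added at every batch index, so outputs at the same position agree across the batch.
def _demo_position_embedding():
    pe = PositionEmbedding(max_len=100, d_model=16)
    out = pe(torch.zeros(2, 10, 16))
    assert out.shape == (2, 10, 16)
    assert torch.allclose(out[0], out[1])  # identical position vectors across the batch
    return out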

class MultiheadAttention(nn.Module):
    def __init__(self, n_heads, d_model):
        super(MultiheadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        # Separate projection weights/biases for queries, keys, values, and the output.
        self.q_weight = nn.Parameter(torch.empty((d_model, d_model), dtype=torch.float32), requires_grad=True)
        self.k_weight = nn.Parameter(torch.empty((d_model, d_model), dtype=torch.float32), requires_grad=True)
        self.v_weight = nn.Parameter(torch.empty((d_model, d_model), dtype=torch.float32), requires_grad=True)
        self.out_weight = nn.Parameter(torch.empty((d_model, d_model), dtype=torch.float32), requires_grad=True)
        self.q_bias = nn.Parameter(torch.empty((1, 1, d_model), dtype=torch.float32), requires_grad=True)
        self.k_bias = nn.Parameter(torch.empty((1, 1, d_model), dtype=torch.float32), requires_grad=True)
        self.v_bias = nn.Parameter(torch.empty((1, 1, d_model), dtype=torch.float32), requires_grad=True)
        self.out_bias = nn.Parameter(torch.empty((1, 1, d_model), dtype=torch.float32), requires_grad=True)
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier initialization for all projection weights and biases.
        xavier_uniform_(self.q_weight)
        xavier_uniform_(self.k_weight)
        xavier_uniform_(self.v_weight)
        xavier_uniform_(self.out_weight)
        xavier_normal_(self.q_bias)
        xavier_normal_(self.k_bias)
        xavier_normal_(self.v_bias)
        xavier_normal_(self.out_bias)

    def forward(self, q, k, v, key_padding_mask=None, atten_mask=None):
        # Delegates projection and per-head attention to the standalone function
        # imported at the top of this file.
        res, score = multi_head_scaled_dot_product_attention(
            q, k, v, self.n_heads,
            self.q_weight, self.q_bias,
            self.k_weight, self.k_bias,
            self.v_weight, self.v_bias,
            self.out_weight, self.out_bias,
            key_padding_mask=key_padding_mask, atten_mask=atten_mask)
        return res, score
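
# Reference sketch of what multi_head_scaled_dot_product_attention is assumed to compute,
# based only on the call site above: project q/k/v with the given weights, split into
# n_heads, apply scaled dot-product attention with optional boolean masks, then merge the
# heads and apply the output projection. The helper name, the x @ W projection convention,
# and the boolean-mask convention are assumptions, not the imported implementation.
def _reference_multi_head_attention(q, k, v, n_heads,
                                    q_weight, q_bias, k_weight, k_bias,
                                    v_weight, v_bias, out_weight, out_bias,
                                    key_padding_mask=None, atten_mask=None):
    b, lq, d = q.shape
    lk = k.shape[1]
    d_head = d // n_heads
    # Linear projections; the (1, 1, d) biases broadcast over (batch, length, d).
    q = q @ q_weight + q_bias
    k = k @ k_weight + k_bias
    v = v @ v_weight + v_bias
    # Reshape to (batch, n_heads, length, d_head).
    q = q.view(b, lq, n_heads, d_head).transpose(1, 2)
    k = k.view(b, lk, n_heads, d_head).transpose(1, 2)
    v = v.view(b, lk, n_heads, d_head).transpose(1, 2)
    # Scaled dot-product scores: (batch, n_heads, lq, lk).
    score = q @ k.transpose(-2, -1) / math.sqrt(d_head)
    if atten_mask is not None:                       # e.g. a (lq, lk) causal mask
        score = score.masked_fill(atten_mask, float('-inf'))
    if key_padding_mask is not None:                 # (batch, lk), True at padding positions
        score = score.masked_fill(key_padding_mask[:, None, None, :], float('-inf'))
    score = F.softmax(score, dim=-1)
    out = (score @ v).transpose(1, 2).reshape(b, lq, d)  # merge heads back to (b, lq, d)
    return out @ out_weight + out_bias, score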

class EncoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_fc):
        super(EncoderLayer, self).__init__()
        self.self_mhsa = MultiheadAttention(n_heads, d_model)
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_fc, bias=False),
            nn.ReLU(),
            nn.Linear(d_fc, d_model, bias=False)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, key_padding_mask=None, atten_mask=None):
        # Self-attention sublayer with a post-norm residual connection.
        res, score = self.self_mhsa(x, x, x, key_padding_mask=key_padding_mask, atten_mask=atten_mask)
        res = self.norm1(x + res)
        # Feed-forward sublayer; the residual connects this sublayer's own input `res`
        # (not the raw `x`), matching the pattern used in DecoderLayer below.
        res = self.norm2(res + self.fc(res))
        return res, score

class DecoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_fc):
        super().__init__()
        self.n_heads = n_heads
        self.self_atten = MultiheadAttention(n_heads, d_model)
        self.cross_atten = MultiheadAttention(n_heads, d_model)
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_fc, bias=False),
            nn.ReLU(),
            nn.Linear(d_fc, d_model, bias=False)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, y, memory, y_key_padding_mask=None, self_atten_mask=None,
                memory_key_padding_mask=None, cross_atten_mask=None):
        # Masked self-attention over the target sequence.
        res1, self_score = self.self_atten(y, y, y, key_padding_mask=y_key_padding_mask, atten_mask=self_atten_mask)
        res1 = self.norm1(y + res1)
        # Cross-attention: queries come from the decoder, keys/values from the encoder memory.
        res2, cross_score = self.cross_atten(res1, memory, memory,
                                             key_padding_mask=memory_key_padding_mask,
                                             atten_mask=cross_atten_mask)
        res2 = self.norm2(res1 + res2)
        # Position-wise feed-forward sublayer.
        res3 = self.norm3(res2 + self.fc(res2))
        return res3, self_score, cross_score

class Encoder(nn.Module):
    def __init__(self, n_layers, encoder_layer):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(n_layers)])

    def forward(self, x, key_padding_mask=None, atten_mask=None):
        scores = []
        for layer in self.layers:
            x, score = layer(x, key_padding_mask=key_padding_mask, atten_mask=atten_mask)
            scores.append(score)
        return x, scores

class Decoder(nn.Module):
    def __init__(self, n_layers, decoder_layer):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(n_layers)])

    def forward(self, y, memory, key_padding_mask=None, self_atten_mask=None,
                memory_key_padding_mask=None, cross_atten_mask=None):
        self_scores = []
        cross_scores = []
        for layer in self.layers:
            y, self_score, cross_score = layer(y, memory, key_padding_mask, self_atten_mask,
                                               memory_key_padding_mask, cross_atten_mask)
            self_scores.append(self_score)
            cross_scores.append(cross_score)
        return y, self_scores, cross_scores

class Transformer(nn.Module):
    def __init__(self, d_model, d_fc, n_heads, n_encoder_layers, n_decoder_layers):
        super(Transformer, self).__init__()
        encoder_layer = EncoderLayer(n_heads, d_model, d_fc)
        self.encoder = Encoder(n_encoder_layers, encoder_layer)
        decoder_layer = DecoderLayer(n_heads, d_model, d_fc)
        self.decoder = Decoder(n_decoder_layers, decoder_layer)

    def forward(self, x, y, x_key_padding_mask=None, x_self_atten_mask=None,
                y_key_padding_mask=None, y_self_atten_mask=None,
                y_mem_key_padding_mask=None, y_cross_atten_mask=None):
        memory, x_self_scores = self.encoder(x, x_key_padding_mask, x_self_atten_mask)
        y, y_self_scores, y_cross_scores = self.decoder(y, memory, y_key_padding_mask, y_self_atten_mask,
                                                        y_mem_key_padding_mask, y_cross_atten_mask)
        return memory, y, [x_self_scores, y_self_scores, y_cross_scores]
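
# Minimal usage sketch. The hyperparameters and the boolean causal-mask convention are
# assumptions; the real mask convention must match multi_head_scaled_dot_product_attention.
# It shows the expected tensor shapes and where the position embedding and masks plug in.
if __name__ == "__main__":
    d_model, d_fc, n_heads = 512, 2048, 8
    model = Transformer(d_model, d_fc, n_heads, n_encoder_layers=6, n_decoder_layers=6)
    pos_emb = PositionEmbedding(max_len=100, d_model=d_model)

    b, src_len, tgt_len = 2, 10, 7
    x = pos_emb(torch.randn(b, src_len, d_model))   # encoder input (already embedded)
    y = pos_emb(torch.randn(b, tgt_len, d_model))   # decoder input (already embedded)

    # Upper-triangular causal mask so each target position attends only to earlier positions.
    causal_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)

    memory, out, scores = model(x, y, y_self_atten_mask=causal_mask)
    print(memory.shape, out.shape)   # torch.Size([2, 10, 512]) torch.Size([2, 7, 512])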