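Key contributions:
- This paper builds on the Transformer and proposes an efficient model whose complexity is linear in the sequence length.
- The main change is to the attention mechanism. The idea is to reduce the role of the full pairwise attention matrix: instead of the original T×T attention matrix, the method uses a 1×T vector of attention weights (one per head) to pool the whole sequence into a single global vector.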
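Why a 1×T vector gives linear complexity can be seen from a rough per-layer count of the attention part. The Q/K/V and output projections cost O(T·D²) in both models and are left out. Here d_a = D/h is the per-head size, and W_a ∈ R^{D×h} stands for the scoring layer (`query_att` in the ESPnet code); the softmax runs over the time axis. This is a back-of-the-envelope sketch under those assumptions:

```latex
\begin{aligned}
\text{standard:}\quad & \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_a}}\right)V \text{ per head, with } T\times T \text{ scores} && \Rightarrow O(T^2 D) \\
\text{additive:}\quad & \alpha = \operatorname{softmax}_T\!\left(\frac{(Q W_a)^\top}{\sqrt{d_a}}\right) \in \mathbb{R}^{h\times T}\text{, one } 1\times T \text{ vector per head} && \Rightarrow O(T D h) \\
& \text{weighted pooling and element-wise products} && \Rightarrow O(T D)
\end{aligned}
```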
[Figure: attention structure diagram]
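Here the input is again mapped to Q, K and V by linear projections (with V sharing Q's projection), and a weight vector for Q is then computed from Q itself.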
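The weights are derived from Q through the following shape changes, where B is the batch size, T the sequence length, D the hidden size and h the number of heads:
weight: (B,T,D) -> (B,T,h) -> (B,h,T) -> (B,h,1,T)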
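The shape flow above can be traced with a minimal PyTorch sketch. The function name `query_weights` and the toy sizes are illustrative assumptions. The scoring layer mirrors `query_att` in the ESPnet code (a `Linear(D, h)` scaled by the square root of the per-head size), and `mask` here already uses the converted convention where `True` marks padding. `torch.finfo(...).min` stands in for the code's NumPy lookup of the same value.

```python
import torch


def query_weights(q, att_proj, head_dim, mask=None):
    """Per-head attention weights over time.

    q: (B, T, D) projected queries; att_proj: Linear(D, h);
    mask: (B, 1, T) bool with True at padding positions (assumed convention).
    Returns: (B, h, 1, T).
    """
    # (B, T, D) -> (B, T, h) -> (B, h, T), scaled by sqrt(per-head size)
    scores = att_proj(q).transpose(1, 2) / head_dim ** 0.5
    if mask is not None:
        # Push padded positions to the most negative finite value before softmax
        scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
    weights = torch.softmax(scores, dim=-1)  # normalize over the time axis
    if mask is not None:
        weights = weights.masked_fill(mask, 0.0)  # exact zeros at padding
    return weights.unsqueeze(2)  # (B, h, T) -> (B, h, 1, T)


torch.manual_seed(0)
B, T, D, h = 2, 5, 8, 2
q = torch.randn(B, T, D)
att = torch.nn.Linear(D, h)
mask = torch.zeros(B, 1, T, dtype=torch.bool)
mask[1, 0, 3:] = True  # second sequence has 3 real steps, 2 padded
w = query_weights(q, att, D // h, mask)
```

Each of the h rows is a distribution over the T time steps, so only T numbers per head are stored instead of a T×T matrix.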
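Next, the weights are multiplied with Q split into heads, pooling Q into a single global query vector (ad is the per-head dimension, D = h*ad): weight*query = (B,h,1,ad) -> (B,1,h,ad) -> (B,1,D) -> (B,T,D), where the last step repeats the global query at every time step.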
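A sketch of this pooling step, assuming the weights from the previous step are already available (random stand-ins here). `pool` is a hypothetical helper name: it splits the D channels into h heads of size ad, takes the weighted sum over time with `torch.matmul`, and merges the heads back into one (B,1,D) vector, following the same reshapes as the ESPnet code.

```python
import torch


def pool(weights, x, num_heads):
    """weights: (B, h, 1, T); x: (B, T, D) -> pooled global vector (B, 1, D)."""
    B, T, D = x.shape
    head_dim = D // num_heads
    x_heads = x.reshape(B, T, num_heads, head_dim).transpose(1, 2)  # (B, h, T, ad)
    pooled = torch.matmul(weights, x_heads)                          # (B, h, 1, ad)
    return pooled.transpose(1, 2).reshape(B, 1, D)                   # (B, 1, h, ad) -> (B, 1, D)


torch.manual_seed(0)
B, T, D, h = 2, 5, 8, 2
q = torch.randn(B, T, D)
weights = torch.softmax(torch.randn(B, h, 1, T), dim=-1)  # stand-in for the query weights
pooled_q = pool(weights, q, h)
pooled_q_repeat = pooled_q.repeat(1, T, 1)  # (B, T, D): same global vector at every step
```

Each head pools only its own ad-sized slice of Q, so the (B,1,D) result is the concatenation of h weighted averages over time.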
The result is then multiplied element-wise with K.
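This element-wise product costs only O(T·D). A small check, with random tensors standing in for the real projections, that PyTorch broadcasting of the (B,1,D) global query gives the same result as the explicit `repeat` in the ESPnet code, without allocating the repeated copy:

```python
import torch

torch.manual_seed(0)
B, T, D = 2, 5, 8
k = torch.randn(B, T, D)         # stand-in for the projected keys
pooled_q = torch.randn(B, 1, D)  # stand-in for the global query

p_repeat = k * pooled_q.repeat(1, T, 1)  # explicit copy along time, as in ESPnet
p_broadcast = k * pooled_q               # broadcasting over T, same values
```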
The weight vector for this product is computed the same way as for Q, but with its own scoring layer.
Likewise, these weights pool the product into a single global key vector.
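A sketch of the key branch under the same toy sizes (random stand-ins, masking omitted). Two things differ from the query branch in the ESPnet code: the scores come from a separate layer (`key_att`) applied to the product from the previous step, and the pooled key stays split by head as (B,h,1,ad), because the next step multiplies it with Q head by head. The `torch.einsum` line is an equivalent alternative formulation, not what ESPnet uses.

```python
import torch

torch.manual_seed(0)
B, T, D, h = 2, 5, 8, 2
ad = D // h
p = torch.randn(B, T, D)         # stand-in for k * pooled_q
key_att = torch.nn.Linear(D, h)  # separate scoring layer for the key branch

# (B, T, D) -> (B, T, h) -> (B, h, T), softmax over time
beta = torch.softmax(key_att(p).transpose(1, 2) / ad ** 0.5, dim=-1)
p_heads = p.reshape(B, T, h, ad).transpose(1, 2)  # (B, h, T, ad)

pooled_k = torch.matmul(beta.unsqueeze(2), p_heads)  # (B, h, 1, ad), kept per head
pooled_k_einsum = torch.einsum("bht,bhtd->bhd", beta, p_heads).unsqueeze(2)
```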
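Finally, the global key is multiplied element-wise with V; in the ESPnet code the result then goes through a linear output layer (transform) and is added back to Q as a residual.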
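A sketch of this last step, with random stand-ins for Q and the pooled key (dropout omitted). Since V reuses Q, the value here is simply `q`. The sketch also computes the same product a second way, as an element-wise product between Q and the global key flattened to (B,1,D), to show the per-head form is just a broadcasted element-wise multiply.

```python
import torch

torch.manual_seed(0)
B, T, D, h = 2, 5, 8, 2
ad = D // h
q = torch.randn(B, T, D)             # query projection, reused as the value
pooled_k = torch.randn(B, h, 1, ad)  # stand-in for the global key
transform = torch.nn.Linear(D, D)    # output projection

q_heads = q.reshape(B, T, h, ad).transpose(1, 2)              # (B, h, T, ad)
u = (pooled_k * q_heads).transpose(1, 2).reshape(B, T, D)     # (B, T, h, ad) -> (B, T, D)
u_flat = q * pooled_k.reshape(B, 1, D)                        # equivalent flattened form
out = transform(u) + q                                        # output layer plus residual
```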
Here Q and V are the same tensor: the query projection also serves as the value (parameter sharing).
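The steps above can be collected into one set of equations. This is a reconstruction from the ESPnet code in this post (including its biases and residual connection), not a copy of the paper's figures, whose notation may differ. x_i is the input at position i, the superscript (j) denotes the j-th head slice of size d_a = D/h, w_j^{(q)}, b_j^{(q)} are the j-th row and bias of `query_att` (likewise `key_att` for the key branch), and W_O, b_O belong to `transform`. Dropout and masking are omitted.

```latex
\begin{aligned}
q_i &= W_Q x_i + b_Q, \qquad k_i = W_K x_i + b_K, \qquad v_i = q_i \\
\alpha_{j,i} &= \operatorname{softmax}_i\!\left(\frac{w_j^{(q)\top} q_i + b_j^{(q)}}{\sqrt{d_a}}\right) \\
\bar q^{(j)} &= \textstyle\sum_{i=1}^{T} \alpha_{j,i}\, q_i^{(j)}, \qquad \bar q = [\bar q^{(1)}; \dots; \bar q^{(h)}] \in \mathbb{R}^{D} \\
p_i &= \bar q \odot k_i \\
\beta_{j,i} &= \operatorname{softmax}_i\!\left(\frac{w_j^{(k)\top} p_i + b_j^{(k)}}{\sqrt{d_a}}\right) \\
\bar k^{(j)} &= \textstyle\sum_{i=1}^{T} \beta_{j,i}\, p_i^{(j)}, \qquad \bar k = [\bar k^{(1)}; \dots; \bar k^{(h)}] \in \mathbb{R}^{D} \\
u_i &= \bar k \odot v_i = \bar k \odot q_i \\
o_i &= W_O u_i + b_O + q_i
\end{aligned}
```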
The implementation in ESPnet:
import numpy
import torch


class FastSelfAttention(torch.nn.Module):
    """Fast self-attention used in Fastformer."""

    def __init__(
        self,
        size,
        attention_heads,
        dropout_rate,
    ):
        super().__init__()
        if size % attention_heads != 0:
            raise ValueError(
                f"Hidden size ({size}) is not an integer multiple "
                f"of attention heads ({attention_heads})"
            )
        self.attention_head_size = size // attention_heads
        self.num_attention_heads = attention_heads

        self.query = torch.nn.Linear(size, size)
        self.query_att = torch.nn.Linear(size, attention_heads)
        self.key = torch.nn.Linear(size, size)
        self.key_att = torch.nn.Linear(size, attention_heads)
        self.transform = torch.nn.Linear(size, size)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def espnet_initialization_fn(self):
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, torch.nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def transpose_for_scores(self, x):
        """Reshape and transpose to compute scores.

        Args:
            x: (batch, time, size = n_heads * attn_dim)

        Returns:
            (batch, n_heads, time, attn_dim)
        """
        new_x_shape = x.shape[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.reshape(*new_x_shape).transpose(1, 2)

    def forward(self, xs_pad, mask):
        """Forward method.

        Args:
            xs_pad: (batch, time, size = n_heads * attn_dim)
            mask: (batch, 1, time), nonpadding is 1, padding is 0

        Returns:
            torch.Tensor: (batch, time, size)
        """
        batch_size, seq_len, _ = xs_pad.shape
        mixed_query_layer = self.query(xs_pad)  # (batch, time, size)
        mixed_key_layer = self.key(xs_pad)  # (batch, time, size)

        if mask is not None:
            mask = mask.eq(0)  # padding is 1, nonpadding is 0

        # (batch, n_heads, time)
        query_for_score = (
            self.query_att(mixed_query_layer).transpose(1, 2)
            / self.attention_head_size**0.5
        )
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
                ).min
            )
            query_for_score = query_for_score.masked_fill(mask, min_value)
            query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(
                mask, 0.0
            )
        else:
            query_weight = torch.softmax(query_for_score, dim=-1)

        query_weight = query_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        query_layer = self.transpose_for_scores(
            mixed_query_layer
        )  # (batch, n_heads, time, attn_dim)

        pooled_query = (
            torch.matmul(query_weight, query_layer)
            .transpose(1, 2)
            .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
        )  # (batch, 1, size = n_heads * attn_dim)
        pooled_query = self.dropout(pooled_query)
        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)

        mixed_query_key_layer = (
            mixed_key_layer * pooled_query_repeat
        )  # (batch, time, size)

        # (batch, n_heads, time)
        query_key_score = (
            self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
        ).transpose(1, 2)
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
                ).min
            )
            query_key_score = query_key_score.masked_fill(mask, min_value)
            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
                mask, 0.0
            )
        else:
            query_key_weight = torch.softmax(query_key_score, dim=-1)

        query_key_weight = query_key_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        key_layer = self.transpose_for_scores(
            mixed_query_key_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_key = torch.matmul(
            query_key_weight, key_layer
        )  # (batch, n_heads, 1, attn_dim)
        pooled_key = self.dropout(pooled_key)

        # NOTE: value = query, due to param sharing
        weighted_value = (pooled_key * query_layer).transpose(
            1, 2
        )  # (batch, time, n_heads, attn_dim)
        weighted_value = weighted_value.reshape(
            weighted_value.shape[:-2]
            + (self.num_attention_heads * self.attention_head_size,)
        )  # (batch, time, size)
        weighted_value = (
            self.dropout(self.transform(weighted_value)) + mixed_query_layer
        )

        return weighted_value
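The `forward` method expects `mask` with shape (batch, 1, time), 1 for real frames and 0 for padding. A minimal sketch for building it from sequence lengths; `lengths_to_mask` is a hypothetical helper written for illustration, not an ESPnet function:

```python
import torch


def lengths_to_mask(lengths, max_len=None):
    """Hypothetical helper: (B,) lengths -> (B, 1, T) bool mask, True for real frames."""
    max_len = int(lengths.max()) if max_len is None else max_len
    steps = torch.arange(max_len, device=lengths.device)  # (T,)
    # Compare every time index with each sequence length: (1, T) < (B, 1) -> (B, T)
    return (steps[None, :] < lengths[:, None]).unsqueeze(1)  # (B, 1, T)


mask = lengths_to_mask(torch.tensor([5, 3]))
```

With the class above, a call would then look like `FastSelfAttention(256, 4, 0.1)(xs_pad, mask)` for `xs_pad` of shape (batch, time, 256). A boolean mask works here because `forward` only calls `mask.eq(0)` on it to flip the convention.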