I'm a beginner writing these posts as study notes, in the hope that they also help others who are just getting started. Corrections from experienced readers are very welcome! Content will be removed immediately upon any infringement notice.
Contents
I. Architecture diagram
II. Code walkthrough
1. __init__
(1) Parameter setup
(2) Word embeddings (parallel)
(3) Transformer
2. forward
(1) Word embeddings (parallel)
(2) Transformer
(3) Parallel logits
(4) Serial or parallel output
I. Architecture diagram
(Figure: overall structure of GPT2Model; image not included here.)
II. Code walkthrough
This part of the code lives in model/gpt2_modeling.py.
1. __init__
(1) Parameter setup
- num_layers: number of transformer layers;
- vocab_size: vocabulary size;
- hidden_size: hidden dimension of the input layer;
- num_attention_heads: number of attention heads;
- embedding_dropout_prob: dropout probability for the embeddings;
- attention_dropout_prob: dropout probability for self-attention;
- output_dropout_prob: dropout probability for the output;
- max_sequence_length: maximum sequence length (length of each input sequence);
- max_memory_length: maximum number of cached hidden states (memories) kept from previous segments;
- checkpoint_activations: whether to enable activation checkpointing;
- checkpoint_num_layers: number of layers per checkpoint chunk;
- parallel_output: whether the output logits stay parallel (sharded) or are gathered into serial form;
- query_window: window size for the sparse-attention processing;
- key_window_times: controls the number of windows in the sparse-attention processing;
- num_pivot: total number of pivot tokens in the sparse-attention processing;
```python
class GPT2Model(torch.nn.Module):
    """GPT-2 Language model.

    The output of the forward method are the logits (parallel or
    serial depending on the `parallel_output` flag).
    """

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 max_memory_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 parallel_output=True,
                 query_window=128,
                 key_window_times=6,
                 num_pivot=768):
        super(GPT2Model, self).__init__()

        self.parallel_output = parallel_output

        # Gaussian initialization (mean 0, standard deviation 0.02).
        init_method = init_method_normal(std=0.02)
```
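`init_method_normal` is a helper defined elsewhere in the repo (it typically wraps `torch.nn.init.normal_`). As a rough idea of its shape, here is a hypothetical pure-Python stand-in: a factory that returns an initializer drawing weights from a Gaussian with the given standard deviation.

```python
import random

def init_method_normal(std=0.02):
    """Factory returning an initializer that fills a flat weight list
    with samples from N(0, std^2). Pure-Python sketch; the repo's
    helper operates on torch tensors instead of lists."""
    def init_(weight):
        for i in range(len(weight)):
            weight[i] = random.gauss(0.0, std)
        return weight
    return init_
```

The factory pattern matters because the same `init_method` object is handed to several submodules (embeddings, transformer), so they all share one initialization scheme.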
(2)Word embeddings (parallel)
```python
# Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding(
    vocab_size, hidden_size, init_method=init_method)
```
For details, see the post "Word embeddings (parallel) in CogView" on tt丫's CSDN blog.
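To make "parallel" concrete: VocabParallelEmbedding splits the embedding table by rows, so with p model-parallel ranks each rank stores only vocab_size/p rows. A minimal sketch of the row-range bookkeeping (hypothetical helper name; the real class additionally masks out-of-range token ids and all-reduces the partial results):

```python
def vocab_range(vocab_size, world_size, rank):
    """Contiguous slice [start, end) of vocabulary rows owned by `rank`
    when the embedding table is split across `world_size` ranks."""
    per_partition = vocab_size // world_size  # assumes even divisibility
    start = rank * per_partition
    return start, start + per_partition
```

For example, with vocab_size=1024 and 4 ranks, rank 3 owns rows [768, 1024).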
(3)Transformer
```python
# Transformer
self.transformer = mpu.GPT2ParallelTransformer(num_layers,
                                               hidden_size,
                                               num_attention_heads,
                                               max_sequence_length,
                                               max_memory_length,
                                               embedding_dropout_prob,
                                               attention_dropout_prob,
                                               output_dropout_prob,
                                               checkpoint_activations,
                                               checkpoint_num_layers,
                                               query_window=query_window,
                                               key_window_times=key_window_times,
                                               num_pivot=num_pivot)
```
For details, see the post "Transformer in CogView" on tt丫's CSDN blog.
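On the two checkpointing arguments: checkpoint_activations trades compute for memory. Instead of storing every layer's activations for the backward pass, the transformer stores only chunk boundaries and recomputes the layers inside each chunk during backward; checkpoint_num_layers sets the chunk size. A sketch of the chunking (hypothetical helper; the real code drives the checkpoint function over these ranges):

```python
def checkpoint_chunks(num_layers, checkpoint_num_layers=1):
    """Split `num_layers` transformer layers into [start, end) chunks;
    each chunk's activations are recomputed during the backward pass."""
    return [(start, min(start + checkpoint_num_layers, num_layers))
            for start in range(0, num_layers, checkpoint_num_layers)]
```

Larger chunks mean fewer stored boundary activations (less memory) but more recomputation per backward pass.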
2. forward
```python
def forward(self, input_ids, position_ids, attention_mask,
            txt_indices_bool, img_indices_bool, is_sparse, *mems):
```
(1)Word embeddings (parallel)
The output shape is (b, s, h), where b is the batch size, s the sequence length, and h the hidden_size.
```python
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
embeddings = words_embeddings
```
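The lookup itself is just row indexing: each token id in the (b, s) input selects an (h,)-row from the table, giving (b, s, h). A toy pure-Python version to illustrate the shapes:

```python
def embed(input_ids, table):
    """input_ids: (b, s) token ids; table: (v, h) embedding rows
    -> output of shape (b, s, h)."""
    return [[list(table[tok]) for tok in seq] for seq in input_ids]
```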
(2)Transformer
```python
# Transformer.
transformer_output = self.transformer(embeddings, position_ids, attention_mask,
                                      txt_indices_bool, img_indices_bool,
                                      is_sparse, *mems)
# logits is the output; hidden_layers collects the memories (*mems).
logits, *hidden_layers = transformer_output
```
(3)Parallel logits
```python
# Parallel logits.
# Copy into the model-parallel region.
logits_parallel = mpu.copy_to_model_parallel_region(logits)
# Linear projection with the tied embedding weights.
logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
```
The resulting shape is (b, s, h) × (v/p, h)^T = (b, s, v/p), where v is vocab_size and p the number of model-parallel partitions.
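Each rank thus multiplies the (b, s, h) hidden states by its own (v/p, h) shard of the tied embedding matrix, producing logits only for the vocabulary rows it owns. A toy version of that per-shard projection (pure Python, no actual parallelism):

```python
def shard_logits(hidden, weight_shard):
    """hidden: (b, s, h); weight_shard: (v/p, h) -> logits (b, s, v/p),
    i.e. hidden @ weight_shard^T computed per position."""
    return [[[sum(h * w for h, w in zip(h_vec, w_row))
              for w_row in weight_shard]
             for h_vec in seq]
            for seq in hidden]
```

Because the weight is the embedding table itself (weight tying), no separate output projection matrix is learned.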
(4) Serial or parallel output
```python
if self.parallel_output:
    # Parallel: return the vocab-sharded logits directly.
    return (logits_parallel, *hidden_layers)
# Serial: gather the shards into full-vocabulary logits first.
return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers)
```
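gather_from_model_parallel_region concatenates the per-rank shards along the last (vocabulary) dimension, turning p tensors of shape (b, s, v/p) back into one (b, s, v) tensor. Keeping the sharded form with parallel_output=True is typically done so a vocab-parallel cross-entropy loss can consume the shards directly during training. A single-process sketch of the concatenation:

```python
def gather_last_dim(shards):
    """shards: p nested lists, each of shape (b, s, v/p)
    -> one (b, s, v) nested list."""
    b, s = len(shards[0]), len(shards[0][0])
    return [[sum((shard[i][j] for shard in shards), [])  # concat vocab dim
             for j in range(s)]
            for i in range(b)]
```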
Criticism and corrections are welcome in the comments. Thanks!