ChatGLM-6B论文代码笔记

news/2024/5/1 13:02:21/文章来源:https://blog.csdn.net/weixin_46133588/article/details/129975428

ChatGLM-6B


文章目录

  • ChatGLM-6B
  • 前言
  • 一、原理
    • 1.1 优势
    • 1.2 实验
    • 1.3 特点:
    • 1.4 相关知识点
  • 二、实验
    • 2.1 环境基础
    • 2.2 构建环境
    • 2.3 安装依赖
    • 2.4 运行
    • 2.5 数据
    • 2.6 构建前端页面
  • 3 总结


前言

Github:https://github.com/THUDM/ChatGLM-6B
参考链接:
https://chatglm.cn/blog


一、原理

1.1 优势

开源

1.2 实验

在这里插入图片描述
在这里插入图片描述

1.3 特点:

优点:

  • 充分的中英双语预训练: ChatGLM-6B 在 1:1 比例的中英语料上训练了 1T 的 token 量,兼具双语能力。
  • 优化的模型架构和大小: 吸取 GLM-130B 训练经验,修正了二维 RoPE 位置编码实现,使用传统FFN结构。6B(62亿)的参数大小,也使得研究者和个人开发者自己微调和部署 ChatGLM-6B 成为可能。
  • 较低的部署门槛: FP16 半精度下,ChatGLM-6B 需要至少 13GB 的显存进行推理,结合模型量化技术,这一需求可以进一步降低到 10GB(INT8) 和 6GB(INT4), 使得 ChatGLM-6B 可以部署在消费级显卡上。
  • 更长的序列长度: 相比 GLM-10B(序列长度1024),ChatGLM-6B 序列长度达 2048,支持更长对话和应用。
  • 人类意图对齐训练: 使用了监督微调(Supervised Fine-Tuning)、反馈自助(Feedback Bootstrap)、人类反馈强化学习(Reinforcement Learning from Human Feedback) 等方式,使模型初具理解人类指令意图的能力。输出格式为 markdown,方便展示。

缺点:

  • 模型容量较小: 6B 的小容量,决定了其相对较弱的模型记忆和语言能力。在面对许多事实性知识任务时,ChatGLM-6B 可能会生成不正确的信息;她也不擅长逻辑类问题(如数学、编程)的解答。
  • 可能会产生有害说明或有偏见的内容:ChatGLM-6B 只是一个初步与人类意图对齐的语言模型,可能会生成有害、有偏见的内容。
  • 较弱的多轮对话能力:ChatGLM-6B 的上下文理解能力还不够充分,在面对长答案生成,以及多轮对话的场景时,可能会出现上下文丢失和理解错误的情况。
  • 英文能力不足:训练时使用的指示大部分都是中文的,只有一小部分指示是英文的。因此在使用英文指示时,回复的质量可能不如中文指示的回复,甚至与中文指示下的回复矛盾。
  • 易被误导:ChatGLM-6B 的“自我认知”可能存在问题,很容易被误导并产生错误的言论。例如当前版本模型在被误导的情况下,会在自我认知上发生偏差。即使该模型经过了1万亿标识符(token)左右的双语预训练,并且进行了指令微调和人类反馈强化学习(RLHF),但是因为模型容量较小,所以在某些指示下可能会产生有误导性的内容。

1.4 相关知识点

P-tuning的原理, 论文的原理比较简单,

    def enable_input_require_grads(self):"""Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keepingthe model weights fixed."""def make_inputs_require_grads(module, input, output):output.requires_grad_(True)self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)

在这里插入图片描述
蓝色部分都是冻结的,橙色部分是可训练的参数。
核心代码展示:
在这里插入图片描述
layer_past 是 paper中的layer prompt[i], 具体来说就是参与到content_vector的计算中了
其核心attention的计算如下:

def attention_fn(self,query_layer,key_layer,value_layer,attention_mask,hidden_size_per_partition,layer_id,layer_past=None,   就是layer prompt[i]scaling_attention_score=True,use_cache=False,
):if layer_past is not None:past_key, past_value = layer_past[0], layer_past[1]key_layer = torch.cat((past_key, key_layer), dim=0)value_layer = torch.cat((past_value, value_layer), dim=0)# seqlen, batch, num_attention_heads, hidden_size_per_attention_headseq_len, b, nh, hidden_size = key_layer.shapeif use_cache:present = (key_layer, value_layer)else:present = Nonequery_key_layer_scaling_coeff = float(layer_id + 1)if scaling_attention_score:query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)# ===================================# Raw attention scores. [b, np, s, s]# ===================================# [b, np, sq, sk]output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))# [sq, b, np, hn] -> [sq, b * np, hn]query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)# [sk, b, np, hn] -> [sk, b * np, hn]key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)matmul_result = torch.empty(output_size[0] * output_size[1],output_size[2],output_size[3],dtype=query_layer.dtype,device=query_layer.device,)matmul_result = torch.baddbmm(matmul_result,query_layer.transpose(0, 1),  # [b * np, sq, hn]key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]beta=0.0,alpha=1.0,)# change view to [b, np, sq, sk]attention_scores = matmul_result.view(*output_size)if self.scale_mask_softmax:self.scale_mask_softmax.scale = query_key_layer_scaling_coeffattention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())else:if not (attention_mask == 0).all():# if auto-regressive, skipattention_scores.masked_fill_(attention_mask, -10000.0)dtype = attention_scores.dtypeattention_scores = attention_scores.float()attention_scores = attention_scores * query_key_layer_scaling_coeffattention_probs = F.softmax(attention_scores, dim=-1)attention_probs = attention_probs.type(dtype)# =========================# Context layer. [sq, b, hp]# =========================# value_layer -> context layer.# [sk, b, np, hn] --> [b, np, sq, hn]# context layer shape: [b, np, sq, hn]output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))# change view [sk, b * np, hn]value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)# change view [b * np, sq, sk]attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)# matmul: [b * np, sq, hn]context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))# change view [b, np, sq, hn]context_layer = context_layer.view(*output_size)# [b, np, sq, hn] --> [sq, b, np, hn]context_layer = context_layer.permute(2, 0, 1, 3).contiguous()# [sq, b, np, hn] --> [sq, b, hp]new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)context_layer = context_layer.view(*new_context_layer_shape)outputs = (context_layer, present, attention_probs)return outputs

二、实验

2.1 环境基础

在这里插入图片描述
在这里插入图片描述

2.2 构建环境

conda create -n py310_chat python=3.10       # 创建新环境
source activate py310_chat                   # 激活环境git clone https://github.com/THUDM/ChatGLM-6B.git
cd ChatGLM-6B

2.3 安装依赖

pip install -r requirements.txt
pip install rouge_chinese nltk jieba datasets

2.4 运行

$ cd ptuning/$ sed -i 's/\r//' train.sh$ bash train.sh

train.sh 参数

--do_train
--train_file
AdvertiseGen/train.json
--validation_file
AdvertiseGen/dev.json
--prompt_column
content
--response_column
summary
--overwrite_cache
--model_name_or_path
../chatglm-6b
--output_dir
output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR
--max_source_length
64
--max_target_length
64
--per_device_train_batch_size
16
--per_device_eval_batch_size
1
--gradient_accumulation_steps
2
--predict_with_generate
--max_steps
3000
--logging_steps
10
--save_steps
1000
--learning_rate
1e-2
--pre_seq_len
512

复现结果:
在这里插入图片描述

2.5 数据

prompt = ‘类型#裤版型#宽松风格#性感图案#线条裤型#阔腿裤’
answer = ‘宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。’

2.6 构建前端页面

首先安装 Gradio:pip install gradio,然后运行仓库中的 web_demo.py:

python web_demo.py

程序会运行一个 Web Server,并输出地址。在浏览器中打开输出的地址即可使用。最新版 Demo 实现了打字机效果,速度体验大大提升。注意,由于国内 Gradio 的网络访问较为缓慢,启用 demo.queue().launch(share=True, inbrowser=True,server_name="0.0.0.0", server_port=1902) 时所有网络会经过 Gradio 服务器转发,导致打字机体验大幅下降,现在默认启动方式已经改为 share=False,如有需要公网访问的需求,可以重新修改为 share=True 启动。

3 总结

p-tuning-v2, 只训练prefix embedding,其余的都fixed住。数据还只是单轮的对话。虽然多轮可以直接使用concate上下文进行,这也只是暂时的猜想,后续RLHF如何加入。这里解决的是:

  • glm的架构图
  • transformer.trainer
  • ptuning
  • gradio前端界面
  • FP4、8、16量化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_96755.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GPSS【实践 01】Developing a Greenplum Streaming Server Client 自定义GPSS客户端开发实例

自定义GPSS客户端开发流程1.GPSS是什么2.架构3.组件下载安装4.自定义客户端4.1 GPSS Batch Data API Service Definition4.2 Setting up a Java Development Environment4.3 Generating the Batch Data API Client Classes4.4 Coding the GPSS Batch Data Client4.4.1 Connect …

精准关键词获取-行业搜索词分析

SEO关键词的收集通常可以通过以下几种方法: 根据市场价值、搜索词竞争性和企业实际产品特征进行筛选:确定您的关键词列表之前,建议先进行市场分析,了解您的竞争对手、行业状况和目标受众等信息,以更好的了解所需的特定…

为何ChatGPT如此擅长编造故事?

“幻觉”——人工智能中的一个偏见性术语 AI聊天机器人(如OpenAI的ChatGPT)依赖于一种称为“大型语言模型”(LLM)的人工智能来生成它们的响应。LLM是一种计算机程序,经过数百万文本源的训练,可以阅读并生成“自然语言”文本语言,就像人类自然…

HTTP协议概述 | 简析HTTP请求流程 | HTTP8种请求方法

目录 🌏 HTTP的简单介绍 何为HTTP HTTP1.0与HTTP1.1 🌏 HTTP的请求方法 1、OPTIONS 2、HEAD 3、GET 4、POST 5、PUT 6、DELETE 7、TRACE 8、CONNECT 🌏 HTTP的工作原理 🌏 HTTP请求/响应的步骤 1、客户端连接到Web…

【Linux】用户命令(创建,修改,切换,删除,密码)

目录 1.创建 查看用户信息 查看id 2.修改 修改用户名 修改用户uid 操作前: 操作后 修改组名 操作前: 操作后: 修改组id 操作前: 操作后: 操作前: 操作后: 3.切换用户 4.删除 操作前: 操作…

LeetCode:376. 摆动序列——说什么贪心和动规~

🍎道阻且长,行则将至。🍓 🌻算法,不如说它是一种思考方式🍀算法专栏: 👉🏻123 一、🌱376. 摆动序列 题目描述:如果连续数字之间的差严格地在正数和…

php7类型约束,严格模式

在PHP7之前,函数和类方法不需要声明变量类型 ,任何数据都可以被传递和返回,导致几乎大部分的调用操作都要判断返回的数据类型是否合格。 为了解决这个问题,PHP7引入了类型声明。 目前有两类变量可以声明类型: 形参&a…

拼多多运营中需要采集淘宝天猫京东平台商品详情页面数据上架拼多多店铺,如何使用技术封装接口实现

业务背景:电商平台趋势,平台化。大家可以看到大的电商都开始有自己的平台,其实这个道理很清楚,就是因为这是充分利用自己的流量、自己的商品和服务大效益化的一个过程,因为有平台,可以利用全社会的资源弥补…

RPC调用框架简单介绍

一.Thrift Apache Doris目前使用的RPC调度框架。Thrift是一款基于CS(client -server)架构的RPC通信框架,开发人员可以根据定义Thrift的IDL(interface decription language)文件来定义数据结构和服务接口,灵活性高,支持…

项目5:实现数据字典的上传下载

项目5:实现数据字典的上传下载 1.什么是数据字典?如何设计? 2.业务流程逻辑 3.数据库表的设计 4.实现上传下载逻辑(前端) 5.实现上传逻辑(后端) 6.实现下载依赖(后端&#xff…

代码随想录Day49

今天继续学习动规解决完全背包问题。 322.零钱兑换 给你一个整数数组 coins ,表示不同面额的硬币;以及一个整数 amount ,表示总金额。 计算并返回可以凑成总金额所需的最少的硬币个数 。如果没有任何一种硬币组合能组成总金额,…

vuex中的 mapState, mapMutations

vuex中的 mapState, mapMutations Start 今天使用vuex的过程中,遇到 mapState, mapMutations 这么两个函数,今天学习一下这两个函数。 本文介绍的vuex基于 vuex3.0 1. 官方文档说明 1.1 mapState 官方解释 点击这里&#xff1…

【JUC进阶】详解synchronized锁升级

文章目录1. synchronized概述2. synchronized 的实现原理2.1 Java对象组成2.2 Monitor2.3 从字节码角度看synchronized3. 锁升级3.1 偏向锁3.2 轻量级锁1. synchronized概述 synchronized是一个悲观锁,可以实现线程同步,在多线程的环境下,需…

DIN35电压电流转频率单位脉冲输出信号变换器集电极开路隔离变送器

主要特性 将直流电压或电流信号转换成单位脉冲信号。 精度等级:0.1 级、0.2 级。产品出厂前已检验校正,用户可以直接使用。 国际标准信号输入:0-5V/0-10V/1-5V 等电压信号,0-10mA/0-20mA/4-20mA 等电流信号。 输出标准信号:0-5KHz/0-…

Flink CDC 在京东的探索与实践

摘要:本文整理自京东资深技术专家韩飞,在 Flink Forward Asia 2022 数据集成专场的分享。本篇内容主要分为四个部分: 京东自研 CDC 介绍京东场景的 Flink CDC 优化业务案例未来规划点击查看直播回放和演讲 PPT 一、京东自研 CDC 介绍 京东自研…

小白学Pytorch系列- -torch.distributions API Distributions (1)

小白学Pytorch系列- -torch.distributions API Distributions (1) 分布包包含可参数化的概率分布和抽样函数。这允许构造用于优化的随机计算图和随机梯度估计器。这个包通常遵循TensorFlow分发包的设计。 不可能通过随机样本直接反向传播。但是,有两种主要方法可以…

tomcat中出现RFC7230和RFC3986问题解析

问题截图 问题分析 出现上述问题,是因为各版本tomcat中对特殊字符和请求路径中携带中文参数而产生的错误提示。 解决办法 1、调整tomcat版本 tomcat 7.0.76之前的版本不会出现类似问题 2、tomcat9之前,修改tomcat目录底下的/conf/catalina.properti…

chapter-5 数据库设计

以下课程来源于MOOC学习—原课程请见:数据库原理与应用 考研复习 引言 设计的时候: 我们为什么不能设计成R(学号,课程号,姓名,所咋系,系主任,成绩)? 因为存在数据冗余…

C++算法初级7——二分查找

C算法初级7——二分查找 文章目录C算法初级7——二分查找在升序的数组上进行二分查找总结应用范围应用二分查找的原理:每次排除掉一半答案,使可能的答案区间快速缩小。 二分查找的时间复杂度:O(log n),因为每次询问会使可行区间的…

appium+python自动化测试启动app

一、部署环境 1、依次下载安装以下工具,并配置环境变量: android sdk Nodejs appium appium-doctor Appium-Python-Client pycharm64 ps:安装包下载和配置环境变量的操作步骤跟着网上各路大神的帖子一步一步做就好了,没啥难度 二、连…