手搓 自然语言模型 各种对比数据

news/2024/4/27 10:35:45/文章来源:https://blog.csdn.net/weixin_32759777/article/details/132029801

基础模型和设计思想
最优网络结构

import paddle
import numpy as np
from tqdm import  tqdm
class EmMask(paddle.nn.Layer):def __init__(self, voc_size=19, hidden_size=256, max_len=48):super(EmMask, self).__init__()# 定义输入序列和标签序列self.embedding_layer = paddle.nn.Embedding(voc_size, hidden_size)self.pos_em_layer = paddle.nn.Embedding(max_len, hidden_size)self.pos_to_down = paddle.nn.Linear(hidden_size, 1)self.sample_buffer_data=paddle.zeros([1])# 定义模型计算过程def forward(self, x):# 将输入序列嵌入为向量表示embedded_x = self.embedding_layer(x)  # bs--->bsh# embedded_x  += paddle.fft.fft(embedded_x, axis=1).real()# embedded_p 有权重 后期预测的时候就要参与 这样会造成计算量增加 如果使用 1 代替 减少多样性# 但是使用pos 是 对于任何输入是固定的可以事先弄好的可以事先计算,一个固定的w 而已# 而当前的attention 这个参数是动态的,要通过其他方法来实现动态的 比如scale 多头等# 当前这种方式全靠 开头和结尾 中间固定参数哦 如果使用多个 加上softmax 那么就能完成多头scale 的操作了embedded_p = self.pos_em_layer(paddle.arange(1, x.shape[1] + 1).astype("int64"))embedded_p = self.pos_to_down(embedded_p)xp = embedded_x.transpose([0, 2, 1]).unsqueeze(3) @ embedded_p.transpose([1, 0])# maskmask = paddle.triu(paddle.ones([xp.shape[-1], xp.shape[-1]]))x = xp * maskreturn xclass JustMaskEm(paddle.nn.Layer):def __init__(self, voc_size=19, hidden_size=512, max_len=1024):super(JustMaskEm, self).__init__()# 定义输入序列和标签序列self.em_mask_one = paddle.nn.Embedding(voc_size, hidden_size)self.em_mask_two = EmMask(voc_size, hidden_size, max_len)self.head_layer = paddle.nn.Linear(hidden_size, voc_size,bias_attr=False)self.layer_nor = paddle.nn.LayerNorm(hidden_size)# 定义模型计算过程def forward(self, x):one = self.em_mask_one(x)two = self.em_mask_two(x)x = one* paddle.sum(two, -2).transpose([0,2,1])# x = paddle.sum(x, -2)# x=x.transpose([0, 2, 1])# x = self.head_layer(self.layer_nor(x))x = self.head_layer(self.layer_nor(x))return x# 进行模型训练和预测
# if __name__ == '__main__':
#     net = JustMaskEm()
#     X = paddle.to_tensor([
#         [1, 2, 3, 4],
#         [5, 6, 7, 8]
#     ], dtype='int64')
#     print(net(X).shape)
#     print(net.sample_buffer(X).shape)#
def train_data():net = JustMaskEm(voc_size=len(voc_id))net.load_dict(paddle.load("long_attention_model"))print("加载成功")opt = paddle.optimizer.Adam(parameters=net.parameters(), learning_rate=0.0003)loss_f = paddle.nn.CrossEntropyLoss()loss_avg = []acc_avg = []batch_size = 1000*3bar=tqdm(range(1, 3 * 600))for epoch in bar:np.random.shuffle(data_set)for i, j in [[i, i + batch_size] for i in range(0, len(data_set), batch_size)]:one_data = data_set[i:j]if (len(acc_avg) + 1) % 1000 == 0:# print(np.mean(loss_avg), "____", np.mean(acc_avg))paddle.save(net.state_dict(), "long_attention_model")paddle.save({"data": loss_avg}, "loss_avg")paddle.save({"data": acc_avg}, "acc_avg")one_data = paddle.to_tensor(one_data)in_put = one_data[:, :-1]label = one_data[:, 1:]# label = one_data[:, 1:]out = net(in_put)loss = loss_f(out.reshape([-1, out.shape[-1]]), label.reshape([-1]).astype("int64"))acc = np.mean((paddle.argmax(out, -1)[:, :].reshape([-1]) == label[:, :].reshape([-1])).numpy())# loss = loss_f(out, label.reshape([-1]).astype("int64"))# acc = np.mean((paddle.argmax(out, -1) == label.reshape([-1])).numpy())loss_data = loss.numpy()[0]acc_avg.append(acc)loss_avg.append(loss_data)bar.set_description(desc="{}{}{}{}{}".format(epoch, "____", np.mean(loss_avg), "____", np.mean(acc_avg)))opt.clear_grad()loss.backward()opt.step()if np.mean(acc_avg) > 0.80:opt.set_lr(opt.get_lr() / (np.mean(acc_avg) * 100 + 1))print(np.mean(loss_avg), "____", np.mean(acc_avg))paddle.save(net.state_dict(), "long_attention_model")paddle.save({"data": loss_avg}, "loss_avg")paddle.save({"data": acc_avg}, "acc_avg")if __name__ == "__main__":with open("poetrySong.txt", "r", encoding="utf-8") as f:data1 = f.readlines()data1 = [i.strip().split("::")[-1] for i in data1 if len(i.strip().split("::")[-1]) == 32]voc_id = ["sos"] + sorted(set(np.hstack([list(set(list("".join(i.split())))) for i in data1]))) + ["pad"]data_set = [[voc_id.index(j) for j in i] for i in data1]train_data()

实验对比数据
在这里插入图片描述
两种基本网络结构设计
在这里插入图片描述
在这里插入图片描述

总结

从上面实验数据可知 在使用方案 二的时候 ,如代码写 不断的扩大维度方可提高收敛时候的acc 上限且最高

且该网络模型可以在推理的时候如最后一幅图所示可以,进行单独解码 从而节约算力。

注意:
后面两幅图中 带框的两个是两个不同的方案,不带框的是公共部分
经过测试抛弃了蓝色框的方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_338075.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot集成单元测试调用dao,service

文章目录 Spring Boot集成单元测试调用dao&#xff0c;service1 添加相关依赖2 新建测试类 Spring Boot集成单元测试调用dao&#xff0c;service 1 添加相关依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-st…

Leetcode 144. 二叉树的前序遍历

题目描述 题目链接&#xff1a;https://leetcode.cn/problems/binary-tree-preorder-traversal/description/ 代码实现 class Solution {List<Integer> l new ArrayList<>();public List<Integer> preorderTraversal(TreeNode root) {preoder(root);re…

音频编辑必备技能:怎么将音频转换mp3

丽萨&#xff1a;嘿&#xff0c;听说你最近在研究音频格式转换的方法&#xff0c;有眉目了吗&#xff1f; 凯瑞&#xff1a;没错&#xff0c;我下载了很多高清音乐&#xff0c;发现有些格式的音频文件在我的播放器上打不开&#xff0c;所以想一个转换工具。但是网上软件太多&a…

ElasticSearch基础篇-Java API操作

ElasticSearch基础-Java API操作 演示代码 创建连接 POM依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:sch…

基于C语言 --- 自己写一个三子棋小游戏

C语言程序设计笔记---019 初阶三子棋小游戏(开源)1、arr_main.c程序大纲2、arr_game1.h3、arr_game1.c3.1、 自定义初识化函数 InitBoard( ) 和 自定义显示函数 DisPlayBoard( )3.2、 自定义玩家下棋函数 PlayerMove( )3.4、 自定义电脑下棋函数 ComputerMove( )3.5、 输赢判断…

飞致云开源社区月度动态报告(2023年7月)

自2023年6月起&#xff0c;中国领先的开源软件公司FIT2CLOUD飞致云将以月度为单位发布《飞致云开源社区月度动态报告》&#xff0c;旨在向广大社区用户同步飞致云旗下系列开源软件的发展情况&#xff0c;以及当月主要的产品新版本发布、社区运营成果等相关信息。 飞致云开源大…

c++ | 动态链接库 | 小结

//环境 linux c //生成动态链接库 //然后调用动态链接库中的函数//出现的问题以及解决//注意在win和在linux中调用动态链接库的函数是不一样的//在要生成链接库的cpp文件中比如以后要调用本文件中的某个函数&#xff0c;需要extern "c" 把你定的函数“再封装”避免重…

Postgresql源码(109)并行框架实例与分析

1 PostgreSQL并行参数 系统参数 系统总worker限制&#xff1a;max_worker_processes 默认8 系统总并发限制&#xff1a;max_parallel_workers 默认8 单Query限制&#xff1a;max_parallel_workers_per_gather 默认2 表参数限制&#xff1a;parallel_workers alter table tbl …

4090Ti被取消,NVIDIA还要推出新“甜品卡“

不知不觉距离 NVIDIA RTX 40 系显卡发布已快一年&#xff0c;4090 到 4060 从旗舰到甜品也都差不多了。 不过每个男孩子都想要的礼物 - RTX 4090 Ti &#xff0c;至今仅在春晚发布。 从核心架构上来看&#xff0c;RTX 4090 上的 AD 102-300 也确实不是完全体。 仅拥有144组 S…

适配器模式与装饰器模式对比分析:优雅解决软件设计中的复杂性

适配器模式与装饰器模式对比分析&#xff1a;优雅解决软件设计中的复杂性 在软件设计中&#xff0c;我们常常面临着需要将不同接口或类协调工作的情况&#xff0c;同时还要满足灵活性和可扩展性的需求。为了应对这些挑战&#xff0c;适配器模式和装饰器模式应运而生&#xff0c…

【计算机视觉】BLIP:源代码示例demo(含源代码)

文章目录 一、Image Captioning二、VQA三、Feature Extraction四、Image-Text Matching 一、Image Captioning 首先配置代码&#xff1a; import sys if google.colab in sys.modules:print(Running in Colab.)!pip3 install transformers4.15.0 timm0.4.12 fairscale0.4.4!g…

linux备份与还原系统(类似window上ghost备份还原)

一、摘要 在linux上进行了几年的开发工作 &#xff08;qt ros&#xff09; 突然发现&#xff0c;现在有公司硬件、笔记本台式机一台占一个系统&#xff0c;导致硬件太浪费&#xff0c;又不能用虚拟机&#xff08;有时候要链接硬件必须物理机&#xff09;怎么办&#xff1f; 二…

Spring框架中的Bean的各种加载方式

大家好&#xff0c;这里向大家主要介绍Spring框架以及SpringBoot框架中的Bean的各种加载方式&#xff0c;有时候我们的学习&#xff0c;就是单纯为了工作效率而作为工具使用&#xff0c;于是乎&#xff0c;往往忽略了其最重要的一点&#xff0c;那就是底层原理&#xff01;所以…

什么是MES,什么是WMS,MES与WMS有什么区别?

什么是MES&#xff1f;什么是WMS&#xff1f;以及MES&#xff08;制造执行系统&#xff09;与WMS&#xff08;仓库管理系统&#xff09;的区别&#xff0c;下面分为三块跟大家详细讲解。 一、什么是MES&#xff1f; 1、概念&#xff1a; MES&#xff08;英文全称&#xff1a…

蓝桥杯2018省赛全球变暖dfs

全球变暖 问题描述格式输入格式输出样例输入样例输出评测用例规模与约定解析参考程序 问题描述 格式输入 格式输出 输出一个整数 样例输入 样例输出 1 评测用例规模与约定 最大运行时间&#xff1a;1s最大运行内存: 256M 解析 采用dfs的方式进行搜索&#xff0c;首先输入地…

有点慌,新公司项目构建用的Gradle

入职新公司&#xff0c;构建项目的工具用的gradle&#xff0c;以前没用过&#xff0c;看到一个build.gradle&#xff0c;点进去&#xff0c;心里一句我曹&#xff0c;这写的都是些什么玩意&#xff0c;方得一批&#xff0c;赶紧去补了下课。 好吧&#xff0c;先学点语法&#…

HTML+CSS前端 动态响应用户登录界面

day2 知道了动态响应设计的概念&#xff0c;在原先登录界面的基础上进行升级 动态响应 由于前端页面需要在不同大小和分辨率的屏幕上显示&#xff0c;所以需要它具有动态适应的特性。 常用的方式是在 css 文件中用 media 动态查询&#xff0c;同时使用 flex 弹性布局。 例如&a…

Java集合篇

前言&#xff1a;笔者参考了JavaGuide、三分恶等博主的八股文&#xff0c;结合Chat老师和自己的理解&#xff0c;整理了一篇关于Java集合的八股文。希望对各位读者有所帮助~~ 引言 常见集合有哪些&#xff1f; Java集合相关类和接口都在java.util包中&#xff0c;按照其存储…

国内外遥感数据处理软件对比

1.国内遥感数据处理软件概况 1.1北京航天宏图信息技术股份有限公司 1.1.1公司简介 航天宏图信息技术股份有限公司成立于2008年,是国内遥感和北斗导航卫星应用服务商,致力于卫星应用软件国产化、行业应用产业化、应用服务商业化,研发并掌握了具有完全自主知识产权的PIE(Pix…

Python源码:Tkinter组件布局管理的3种方式

Tkinter组件布局管理可以使用pack()方法、grid()方法和place()方法。pack()方法将组件放置在窗口中&#xff0c;grid()方法将组件放置在网格布局中&#xff0c;place()方法将组件放置在指定位置。 01使用pack()方法布局&#xff1a; 在Tkinter中&#xff0c;pack方法用于将控…