Word2Vec实现文本识别分类

news/2024/5/11 13:21:13/文章来源:https://blog.csdn.net/qq_62904883/article/details/131729881

深度学习训练营之使用Word2Vec实现文本识别分类

  • 原文链接
  • 环境介绍
  • 前言
  • 前置工作
    • 设置GPU
    • 数据查看
    • 构建数据迭代器
  • Word2Vec的调用
  • 生成数据批次和迭代器
  • 模型训练
    • 初始化
    • 拆分数据集并进行训练
  • 预测

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第N4周:用Word2Vec实现文本分类
  • 🍖 原作者:K同学啊|接辅导、项目定制

环境介绍

  • 语言环境:Python3.9.12
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2

前言

本次内容我本来是使用miniconda的环境的,但是好像有文件发生了损坏,出现了如下报错,据我所了解应该是某个文件发生了损坏,应该是之前将anaconda误删有关,有所了解或者有同样问题的朋友可以一起进行探讨

前置工作

设置GPU

如果

# 先进行数据加载
import torch
import torch.nn as nn
import torchvision
import os,PIL,pathlib,warnings
import time
from torchvision import transforms, datasets
from torch import nn
from torch.utils.data.dataset import random_splitwarnings.filterwarnings("ignore")#忽略警告信息
device=torch.device("cuda"if torch.cuda.is_available()else "cpu")
device

device(type=‘cpu’)

数据查看

本次使用的数据集和之前中文文本识别分类的是一样的

import pandas as pd
train_data=pd.read_csv('train.csv',sep='\t',header=None)
train_data.head()

在这里插入图片描述

构建数据迭代器

#构建数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,yx=train_data[0].values[:]
y=train_data[1].values[:]    

添加数据迭代器是为了让数据的随机性增强,进行数据集的划分,可以有效的发挥内存的高利用率

Word2Vec的调用

对Word2Vec进行直接的调用

from gensim.models.word2vec import Word2Vec
import numpy as np
#训练浅层神经网络模型
w2v=Word2Vec(vector_size=100,min_count=3)w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=30)

build_vocab统计输入每一个词汇出现的次数

def average_vec(text):vec=np.zeros(100).reshape((1,100))#表示平均向量#(n,100),其中n表示x中的元素的数量 for word in text:try:vec+=w2v.wv[word].reshape((1,100))except KeyError:continue#未找到,再进行迭代下一个词return vecx_vec=np.concatenate([average_vec(z) for z in x])
w2v.save('w2v_model.pkl')

该步骤将输入的文本转变成了平均向量
对于输入进来的text当中的每一个单词都进行一个查询,确认是否当中有该词,如果有那么就将其添加到vector当中,否则跳出本层循环,查找下一个词.
最后通过np当中的concatenate方法进行一个向量的连接

train_iter=coustom_data_iter(x_vec,y)#训练迭代器
print(len(x),len(y))

12100 12100

设置训练的迭代器

label_name=list(set(train_data[1].values[:]))
print(label_name)
['FilmTele-Play', 'Weather-Query', 'Audio-Play', 'Radio-Listen', 'HomeAppliance-Control', 'Alarm-Update', 'Travel-Query', 'Video-Play', 'Calendar-Query', 'TVProgram-Play', 'Music-Play', 'Other']

生成数据批次和迭代器

text_pipeline=lambda x:average_vec(x)
label_pipeline=lambda x:label_name.index(x)
#lambda语法:lambda  arguments
text_pipeline("我想你了")

在这里插入图片描述

label_pipeline("Travel-Query")

6

这里的结果每次都会不太一样,具有一定的随机性

from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list= [], []for (_text,_label) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)text_list.append(processed_text)# 偏移量,即语句的总词汇量label_list = torch.tensor(label_list, dtype=torch.int64)text_list  = torch.cat(text_list)return text_list.to(device),label_list.to(device)# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,shuffle   =False,collate_fn=collate_batch)

和之前的不同在于没有了offset

模型训练

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, num_class):super(TextClassificationModel, self).__init__()self.fc = nn.Linear(100, num_class)def forward(self, text):return self.fc(text)

初始化

num_class  = len(label_name)
vocab_size = 100000
em_size    = 12
model      = TextClassificationModel(num_class).to(device)
import timedef train(dataloader):model.train()  # 切换为训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 50start_time   = time.time()for idx, (text,label) in enumerate(dataloader):predicted_label = model(text)optimizer.zero_grad()                    # grad属性归零loss = criterion(predicted_label, label) # 计算网络输出和真实值之间的差距,label为真实值loss.backward()                          # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 梯度裁剪optimizer.step()  # 每一步自动更新# 记录acc与losstotal_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()  # 切换为测试模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (text,label) in enumerate(dataloader):predicted_label = model(text)loss = criterion(predicted_label, label)  # 计算loss值# 记录测试数据total_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

拆分数据集并进行训练

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS     = 30 # epoch
LR         = 5  # 学习率
BATCH_SIZE = 64 # batch size for trainingcriterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None# 构建数据集
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:1d} | time: {:4.2f}s | ''valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,time.time() - epoch_start_time,val_acc,val_loss,lr))print('-' * 69)
| epoch 1 |   50/ 152 batches | train_acc 0.742 train_loss 0.02635
| epoch 1 |  100/ 152 batches | train_acc 0.820 train_loss 0.02033
| epoch 1 |  150/ 152 batches | train_acc 0.838 train_loss 0.01927
---------------------------------------------------------------------
| epoch 1 | time: 0.95s | valid_acc 0.819 valid_loss 0.023 | lr 5.000000
---------------------------------------------------------------------
| epoch 2 |   50/ 152 batches | train_acc 0.850 train_loss 0.01876
| epoch 2 |  100/ 152 batches | train_acc 0.849 train_loss 0.02012
| epoch 2 |  150/ 152 batches | train_acc 0.847 train_loss 0.01736
---------------------------------------------------------------------
| epoch 2 | time: 0.92s | valid_acc 0.869 valid_loss 0.016 | lr 5.000000
---------------------------------------------------------------------
| epoch 3 |   50/ 152 batches | train_acc 0.858 train_loss 0.01588
| epoch 3 |  100/ 152 batches | train_acc 0.833 train_loss 0.02008
| epoch 3 |  150/ 152 batches | train_acc 0.864 train_loss 0.01813
---------------------------------------------------------------------
| epoch 3 | time: 0.86s | valid_acc 0.835 valid_loss 0.023 | lr 5.000000
---------------------------------------------------------------------
| epoch 4 |   50/ 152 batches | train_acc 0.883 train_loss 0.01309
| epoch 4 |  100/ 152 batches | train_acc 0.899 train_loss 0.00996
| epoch 4 |  150/ 152 batches | train_acc 0.895 train_loss 0.00927
---------------------------------------------------------------------
| epoch 4 | time: 0.87s | valid_acc 0.888 valid_loss 0.011 | lr 0.500000
---------------------------------------------------------------------
| epoch 5 |   50/ 152 batches | train_acc 0.906 train_loss 0.00834
...
| epoch 30 |  150/ 152 batches | train_acc 0.900 train_loss 0.00717
---------------------------------------------------------------------
| epoch 30 | time: 0.92s | valid_acc 0.886 valid_loss 0.010 | lr 0.000000
---------------------------------------------------------------------
test_acc, test_loss = evaluate(valid_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

在这里插入图片描述

预测

def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text),dtype=torch.float32)print(text.shape)output = model(text)return output.argmax(1).item()ex_text_str = "随便播放一首专辑阁楼里的佛里的歌"
#ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to(device)print("该文本的类别是:%s" % label_name[predict(ex_text_str, text_pipeline)])
torch.Size([1, 100])
该文本的类别是:Music-Play

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_331054.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pycharm新建分支并提送至GitHub

文章目录 前言pycharm创建本地分支Push至远程分支 前言 当我们写的项目代码越来越多时,一个master分支无法满足需求了,这个时候就需要创建分支来管理代码。 创建分支可以快速的回滚到某个节点的版本,也可以多个开发者同时开发一个项目&#…

剑指oferr68-II.二叉树的最近公共祖先

为什么这道题的难度是easy&#xff0c;我感觉挺难的啊&#xff0c;我想了挺久没有一点思路就直接看题解了。题解有两种解法&#xff0c;先看第一种存储父节点 class Solution {Map<Integer,TreeNode> parent new HashMap<Integer,TreeNode>();Set<Integer>…

windows nodejs 版本切换

一、按健winR弹出窗口&#xff0c;键盘输入cmd,然后敲回车。然后进入命令控制行窗口&#xff0c;并输入where node查看之前本地安装的node的路径。 二、找到上面找到的路径&#xff0c;将node.exe所在的父目录里面的所有东西都删除。 三、从官网下载安装包 https://github.com/…

Raft算法之日志复制

Raft算法之日志复制 一、日志复制大致流程 在Leader选举过程中&#xff0c;集群最终会选举出一个Leader节点&#xff0c;而集群中剩余的其他节点将会成为Follower节点。Leader节点除了向Follower节点发送心跳消息&#xff0c;还会处理客户端的请求&#xff0c;并将客户端的更…

AP5193 DC-DC宽电压LED降压恒流驱动器 LED电源驱动IC

产品 AP5193是一款PWM工作模式、外围简单、内置功率MOS管&#xff0c;适用于4.5-100V输入的高精度降压LED恒流驱动芯片。电流2.5A。AP5193可实现线性调光和PWM调光&#xff0c;线性调光脚有效电压范围0.55-2.6V.AP5193 工作频率可以通过RT 外部电阻编程来设定&#xff0c;同时…

linux下一个iic驱动(按键+点灯)-互斥

一、前提&#xff1a; 硬件部分&#xff1a; 1. rk3399开发板&#xff0c;其中的某一路iic&#xff0c;这个作为总线的主控制器 2. gd32单片机&#xff0c;其中的某一路iic&#xff0c;从设备。主要是按键上报和灯的亮灭控制。&#xff08;按键大约30个&#xff0c;灯在键的…

day34-servlet 分页

0目录 servlet 1.分页 分页逻辑1&#xff1a;数据库中20条记录&#xff0c;要求每页5条数据&#xff0c;则一共有4页 分页逻辑2&#xff1a;数据库中21条记录&#xff0c;要求每页5条数据&#xff0c;则一共有5页 分页逻辑3&#xff1a;数据库中19条记录&#xff0c;要求每页…

数字 IC 设计职位经典笔/面试题(二)

共100道经典笔试、面试题目&#xff08;文末可全领&#xff09; FPGA 中可以综合实现为 RAM/ROM/CAM 的三种资源及其注意事项&#xff1f; 三种资源&#xff1a;BLOCK RAM&#xff0c;触发器&#xff08;FF&#xff09;&#xff0c;查找表&#xff08;LUT&#xff09;&#xf…

【算法基础】2.1栈和队列(单调栈和单调队列)

文章目录 例题3302. 表达式求值&#xff08;栈的应用&#xff09;&#x1f62d;&#x1f62d;&#x1f62d;&#x1f62d;&#x1f62d;830. 单调栈知识点解法 154. 滑动窗口 &#xff08;单调队列&#xff09;知识点解法 相关链接 & 相关题目 例题 3302. 表达式求值&…

基于weka手工实现多层感知机(BPNet)

一、BP网络 1.1 单层感知机 单层感知机&#xff0c;就是只有一层神经元&#xff0c;它的模型结构如下1&#xff1a; 对于权重 w w w的更新&#xff0c;我们采用如下公式&#xff1a; w i w i Δ w i Δ w i η ( y − y ^ ) x i (1) w_iw_i\Delta w_i \\ \Delta w_i\eta(y…

RabbitMQ 同样的操作一次成功一次失败

RabbitMQ 是一个功能强大的消息队列系统&#xff0c;广泛应用于分布式系统中。然而&#xff0c;我遇到这样的情况&#xff1a;执行同样的操作&#xff0c;一次成功&#xff0c;一次失败。在本篇博文中&#xff0c;我将探讨这个问题的原因&#xff0c;并提供解决方法。 我是在表…

创作一周年纪念日【道阻且长,行则将至】

✨个人主页&#xff1a; 北 海 &#x1f389;所属专栏&#xff1a; 技术之外的往事 &#x1f383;所处时段&#xff1a; 大学生涯[1/2] 文章目录 一、起点一切皆有定数 二、成果尽心、尽力 三、相遇孤举者难起&#xff0c;众行者易趋 四、未来长风破浪会有时&#xff0c;直挂云…

markdown2html 转化流程

定义一个extensions function markedMention() {return {extensions: [{name: mention,level: inline,start(src) {// console.log("markedMention start....", src);return src.indexOf(#)},tokenizer(src, tokens) {const rule /^(#[a-zA-Z0-9])\s?/const match…

JMeter websocket接口测试

前言 在一个网站中&#xff0c;很多数据需要即时更新&#xff0c;比如期货交易类的用户资产。在以前&#xff0c;这种功能的实现一般使用http轮询&#xff0c;即客户端用定时任务每隔一段时间向服务器发送查询请求来获取最新值。这种方式的弊端显而易见&#xff1a; 有可能造…

解决Gson解析json字符串,Integer变为Double类型的问题

直接上代码记录下。我代码里没有Gson包&#xff0c;用的是nacos对Gson的封装&#xff0c;只是包不同&#xff0c;方法都一样 import com.alibaba.nacos.shaded.com.google.common.reflect.TypeToken; import com.alibaba.nacos.shaded.com.google.gson.*;import java.util.Map;…

idea-控制台输出乱码问题

idea-控制台输出乱码问题 现象描述&#xff1a; 今天在进行IDEA开发WEB工程调式的时候控制台日志输出了乱码&#xff0c;如下截图 其实开发者大多都知道乱码是 编码不一致导致的&#xff0c;但是有时候就是不知到哪些地方不一致&#xff0c;今天我碰到的情况可能和你的不相同…

前端框架Layui实现动态表格效果用户管理实例(对表格进行CRUD操作-附源码)

目录 一、前言 1.什么是表格 2.表格的使用范围 二、案例实现 1.案例分析 ①根据需求找到文档源码 ②查询结果在实体中没有该属性 2.dao层编写 ①BaseDao工具类 ②UserDao编写 3.Servlet编写 ①R工具类的介绍 ②Useraction编写 4.jsp页面搭建 ①userManage.jsp ②…

Docker基础——初识Docker

Docker架构 Docker 使用客户端-服务器 (C/S) 架构模式&#xff0c;使用远程API来管理和创建Docker容器。 Docker 客户端(Client) : Docker 客户端通过命令行或者其他工具使用 Docker SDK (https://docs.docker.com/develop/sdk/) 与 Docker 的守护进程通信。Docker 主机(Host…

搭建微服务工程 【详细步骤】

一、准备阶段 &#x1f349; 本篇文章用到的技术栈 mysqlmybatis[mp]springbootspringcloud alibaba 需要用到的数据库 订单数据库: SET NAMES utf8mb4; SET FOREIGN_KEY_CHECKS 0;-- ---------------------------- -- Table structure for shop_order -- --------------…

windows下配置pytorch + yolov8+vscode,并自定义数据进行训练、摄像头实时预测

最近由于工程需要&#xff0c;研究学习了一下windows下如何配置pytorch和yolov8&#xff0c;并自己搜集数据进行训练和预测&#xff0c;预测使用usb摄像头进行实时预测。在此记录一下全过程 一、软件安装和配置 1. vscode安装 windows平台开发python&#xff0c;我采用vscod…