基于word2vec 和 fast-pytorch-kmeans 的文本聚类实现,利用GPU加速提高聚类速度

news/2024/5/25 10:42:03/文章来源:https://blog.csdn.net/sjxgghg/article/details/136694127

文章目录

    • 简介
      • GPU加速
    • 代码实现
    • kmeans
    • 聚类结果
    • kmeans 绘图函数
    • 相关资料参考

简介

本文使用text2vec模型,把文本转成向量。使用text2vec提供的训练好的模型权重进行文本编码,不重新训练word2vec模型。

直接用训练好的模型权重,方便又快捷

完整可运行代码如下:
https://github.com/JieShenAI/csdn/blob/main/machine_learning/kmeans_pytorch.ipynb

GPU加速

传统sklearn的TF-IDF文本转向量,在CPU上计算速度较慢。使用text2vec通过cuda加速,加快文本转向量的速度。
传统使用sklearn的kmeans聚类算法在CPU上计算,如遇到大批量的数据,计算耗时太长。
故本文使用fast-pytorch-kmeans 和 kmeans_pytorch包,基于pytorch在GPU上计算,提高聚类速度。

代码实现

装包

pip install fast-pytorch-kmeans text2vec
import torch
import numpy as npfrom text2vec import SentenceModel

不使用SentenceModel模型也可以,在 text2vec 中,还有很多其他的向量编码模型供选择。

文本编码模型

embedder = SentenceModel()

异常情况说明,该模型需要从huggingface下载模型权重,目前被墙了。(请想办法解决,或者尝试其他的编码模型)
在这里插入图片描述

语料库如下:

# Corpus with example sentences
corpus = ['花呗更改绑定银行卡','我什么时候开通了花呗','A man is eating food.','A man is eating a piece of bread.','The girl is carrying a baby.','A man is riding a horse.','A woman is playing violin.','Two men pushed carts through the woods.','A man is riding a white horse on an enclosed ground.',
]
corpus_embeddings = embedder.encode(corpus)
# numpy 转成 pytorch, 并转移到GPU显存中
corpus_embeddings = torch.from_numpy(corpus_embeddings).to('cuda')

如下图所示,编码的向量是768维;

type(corpus_embeddings), corpus_embeddings.shape

在这里插入图片描述

kmeans

kmeans_pytorch vs fast-pytorch-kmeans:
在实验过程中,利用kmeans_pytorch 针对30万个词进行聚类的时候,发现显存炸了,程序崩溃退出。30万个词的词向量,占用显存还不到2G,但是运行kmeans_pytorch后,显存就炸了。

fast-pytorch-kmeans不存在上述显存崩溃的问题。本以为词向量很多会跑很长时间,但fast-pytorch-kmeans在非常短的时间内就完成了kmeans聚类。
后来一想也理解了,先开始在CPU跑花费了很长时间,这是因为CPU并行很差,需要逐个跑完。而在GPU里大量数据拼成一个矩阵,做一个减法,就可以算出批量节点和中心点的距离。

# kmeans
# from kmeans_pytorch import kmeans
from fast_pytorch_kmeans import KMeansnum_class = 3 # 分类类别数
kmeans = KMeans(n_clusters=num_class, mode='euclidean', verbose=1)# 模型预测结果
labels = kmeans.fit_predict(corpus_embeddings)

聚类程序运行如下:

used 2 iterations (0.3682s) to cluster 9 items into 3 clusters

模型中心点坐标:

kmeans.centroids

在这里插入图片描述

聚类结果

class_data = {i:[]for i in range(3)
}for text,cls in zip(corpus, labels):class_data[cls.item()].append(text)class_data

文本聚类结果如下:
0: 女
1:男
2: 花呗
在这里插入图片描述

kmeans 绘图函数

封装了KMeansPlot 绘图类,方便聚类结果可视化

from sklearn.decomposition import PCA
import matplotlib.pyplot as pltclass KMeansPlot:def __init__(self, numClass=4, func_type='PCA'):if func_type == 'PCA':self.func_plot = PCA(n_components=2)elif func_type == 'TSNE':from sklearn.manifold import TSNEself.func_plot = TSNE(2)self.numClass = numClassdef plot_cluster(self, result, pos, cluster_centers=None):plt.figure(2)Lab = [[] for i in range(self.numClass)]index = 0for labi in result:Lab[labi].append(index)index += 1color = ['oy', 'ob', 'og', 'cs', 'ms', 'bs', 'ks', 'ys', 'yv', 'mv', 'bv', 'kv', 'gv', 'y^', 'm^', 'b^', 'k^','g^'] * 3for i in range(self.numClass):x1 = []y1 = []for ind1 in pos[Lab[i]]:# print ind1try:y1.append(ind1[1])x1.append(ind1[0])except:passplt.plot(x1, y1, color[i])if cluster_centers is not None:#绘制初始中心点x1 = []y1 = []for ind1 in cluster_centers:try:y1.append(ind1[1])x1.append(ind1[0])except:passplt.plot(x1, y1, "rv") #绘制中心plt.show()def plot(self, weight, label, cluster_centers=None):pos = self.func_plot.fit_transform(weight)# 高维的中心点坐标,也经过降维处理cluster_centers = self.func_plot.fit_transform(cluster_centers)self.plot_cluster(list(label), pos, cluster_centers)

kmeans.centroids :是一个高维空间的中心点坐标,故在plot函数中,将其降维到2D平面上;

k_plot = KMeansPlot(num_class)
k_plot.plot(corpus_embeddings.to('cpu'),labels.to('cpu'),kmeans.centroids.to('cpu')
)

在这里插入图片描述

完整可运行代码如下:
https://github.com/JieShenAI/csdn/blob/main/machine_learning/kmeans_pytorch.ipynb

相关资料参考

  • 动手实战基于 ML 的中文短文本聚类
  • tfidf和word2vec构建文本词向量并做文本聚类
    提到训练word2vec模型,silhouette_score_show(word2vec, 'word2vec') 轮廓系数,判断分几个类别最好。
  • 机器学习:Kmeans聚类算法总结及GPU配置加速demo
    PyTorch kmeans 加速。from scratch 实现;
  • KMeans算法全面解析与应用案例 通俗易懂的原理讲解
  • pytorch K-means算法的实现 底层代码实现
  • 【pytorch】Kmeans_pytorch用于一般聚类任务的代码模板 使用pytorch封装的kmeans包实现,包括训练和预测;
  • text2vec 包

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1007610.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

selenium 网页自动化-在访问一个网页时弹出的浏览器窗口,我该如何处理?

前言 相信大家在使用selenium做网页自动化时,会遇到如下这样的一个场景: 在你使用get访问某一个网址时,会在页面中弹出如上图所示的弹出框。 首先想到是利用Alert类来处理它。 然而,很不幸,Alert类处理的结果就是没…

springboot273基于JavaWeb的宠物商城网站设计与实现

宠物商城网站的设计与实现 摘 要 传统信息的管理大部分依赖于管理人员的手工登记与管理,然而,随着近些年信息技术的迅猛发展,让许多比较老套的信息管理模式进行了更新迭代,商品信息因为其管理内容繁杂,管理数量繁多导…

留学生课设|R语言|研究方法课设

目录 INSTRUCTIONS Question 1. Understanding Quantitative Research Question 2. Inputting data into Jamovi and creating variables (using the dataset) Question 3. Outliers Question 4. Tests for mean difference Question 5. Correlation Analysis INSTRUCTIO…

Elasticsearch:调整近似 kNN 搜索

在我之前的文章 “Elasticsearch:调整搜索速度”,我详细地描述了如何调整正常的 BM25 的搜索速度。在今天的文章里,我们来进一步探讨如何提高近似 kNN 的搜索速度。希望对广大的向量搜索开发者有一些启示。 Elasticsearch 支持近似 k 最近邻…

C#,数值计算,数据测试用的对称正定矩阵(Symmetric Positive Definite Matrix)的随机生成算法与源代码

C.Hermite 1、对称矩阵 对称矩阵(Symmetric Matrices)是指以主对角线为对称轴,各元素对应相等的矩阵。在线性代数中,对称矩阵是一个方形矩阵,其转置矩阵和自身相等。1855年,埃米特(C.Hermite,1822-1901年)证明了别的数学家发现的一些矩阵类的特征根的特殊性质,如称为埃…

Selenium 学习(0.20)——软件测试之单元测试

我又(浪完)回来了…… 很久没有学习了,今天忙完终于想起来学习了。没有学习的这段时间,主要是请了两个事假(5工作日和10工作日)放了个年假(13天),然后就到现在了。 看了下…

15届蓝桥杯第一期模拟赛所有题目解析

文章目录 🧡🧡t1_字母数🧡🧡问题描述思路代码 🧡🧡t2_大乘积🧡🧡问题描述思路代码 🧡🧡t3_星期几🧡🧡问题描述思路代码 🧡…

ctfshow web入门 php特性总结

1.web89 intval函数的利用,intval函数获取变量的整数值,失败时返回0,空的数组返回,非空数组返回1 num[]1 intval ( mixed $var [, int $base 10 ] ) : int Note: 如果 base 是 0,通过检测 var 的格式来决定使用的进…

构建LVS集群

一、集群的基本理论(一)什么是集群 人群或事物聚集:在日常用语中,群集指的是一大群人或事物密集地聚在一起。例如,“人们群集在广场上”,这里的“群集”是指大量人群聚集的现象。 计算机技术中的集群&…

uniapp微信小程序_自定义交费逻辑编写

一、首先看最终效果 先说下整体逻辑,选中状态为淡紫色,点击哪个金额,充值页面上就显示多少金额 二、代码 <view class"addMoney"><view class"addMoneyTittle">充值金额</view><view class"selfaddmoney" :class"{…

【图论】计算图的n-hop邻居个数,并绘制频率分布直方图

计算图的n-hop邻居个数&#xff0c;并绘制频率分布直方图 在图论中&#xff0c;n-hop邻居&#xff08;或称为K-hop邻居&#xff09;是指从某个顶点出发&#xff0c;通过最短路径&#xff08;即最少的边数&#xff09;可以到达的所有顶点的集合&#xff0c;其中n&#xff08;或…

数据的存储底层详解 - 源码、反码、补码 浮点数的存储

数据的存储 1. 前言2. 数据类型2.1 整形家族2.2 浮点数家族2.3 构造类型&#xff08;自定义类型&#xff09;2.4 指针类型2.5 空类型&#xff08;无类型&#xff09; 3. 整数在内存中的存储4. 大小端5. 浮点数在内存中的存储 1. 前言 大家好&#xff0c;我是努力学习游泳的鱼。…

【开源】SpringBoot框架实验室耗材管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 耗材档案模块2.2 耗材入库模块2.3 耗材出库模块2.4 耗材申请模块2.5 耗材审核模块 三、系统展示四、核心代码4.1 查询耗材品类4.2 查询资产出库清单4.3 资产出库4.4 查询入库单4.5 资产入库 五、免责说明 一、摘要 1.1…

3d场景重建图像渲染 | 神经辐射场NeRF(Neural Radiance Fields)

神经辐射场NeRF&#xff08;Neural Radiance Fields&#xff09; 概念 NeRF&#xff08;Neural Radiance Fields&#xff0c;神经辐射场&#xff09;是一种用于3D场景重建和图像渲染的深度学习方法。它由Ben Mildenhall等人在2020年的论文《NeRF: Representing Scenes as Neur…

如何从 Mac 电脑外部硬盘恢复删除的数据文件

本文向您介绍一些恢复 Mac 外置硬盘数据的快速简便的方法。 Mac 的内部存储空间通常不足以存储所有数据。因此&#xff0c;许多用户通过外部驱动器扩展存储或创建数据备份。然而&#xff0c;与几乎所有其他设备一样&#xff0c;从外部硬盘驱动器丢失有价值的数据并不罕见。由于…

pdf也可以制作成可翻页的电子书吗?

​当然可以&#xff01;PDF文件可以通过一些工具和软件转换成可翻页的电子书。这种转换通常需要将PDF文件中的页面重新排列和格式化&#xff0c;以便它们可以像书籍一样翻页。一些流行的工具包括Adobe Acrobat、PDF转换器等 如果需要将大量PDF文件转换为电子书&#xff0c;可以…

【django framework】ModelSerializer+GenericAPIView,如何获取HTTP请求头中的信息(远程IP、UA等)

【django framework】ModelSerializerGenericAPIView&#xff0c;如何获取HTTP请求头中的信息(远程IP、UA等) 某些时候&#xff0c;我们不得不获取调用当前接口的客户端IP、UA等信息&#xff0c;如果是第一次用Django Restframework&#xff0c;可能会有点懵逼&#xff0c;那么…

机械女生,双非本985硕,目前学了C 基础知识,转嵌入式还是java更好?

作为单片机项目开发的卖课佬&#xff0c;个人建议&#xff0c;先转嵌入式单片机开发方向&#xff0c;哈哈。 java我也学过&#xff0c;还学过oracle、mysql数据库&#xff0c;只是当时没做笔记&#xff0c;找不好充分的装逼证据了。 从实习通过业余时间&#xff0c;学到快正式毕…

微信小程序 uniapp奶茶点单系统r4112

系统功能有&#xff1a;信点单小程序分为小程序部分和后台管理两部分&#xff0c;小程序部分的主要功能包含&#xff1a;用户注册登录&#xff0c;查看商品信息&#xff0c;加入购物车&#xff0c;结算并生成订单&#xff0c;订单管理&#xff0c;资讯管理&#xff0c;个人中心…

前端框架vue的样式操作,以及vue提供的属性功能应用实战

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…