AlphaZero强化学习模型

news/2024/5/21 0:35:26/文章来源:https://blog.csdn.net/qq_29788741/article/details/127274951

搬来了DeepMind的AlphaTensor

DeepMind前不久发在Nature上的论文Discovering faster matrix multiplication algorithms with reinforcement learning引发热议。

这篇论文在德国数学家Volken Strassen「用加法换乘法」思路和算法的基础上,构建了一个基于AlphaZero的强化学习模型,更高效地探索进一步提高矩阵乘法速度的通用方法。

视频发不了就算了哈

基本思路:用加法换乘法

众所周知,矩阵乘法的传统算法是:两个矩阵行列交换相乘,然后求和,作为新矩阵的对应元素。其中涉及到大量的加法和乘法运算。

对于计算机来说,运算加法的速度要远远快于乘法,所以提升运算速度的关键,就是尽量减少乘法运算的次数,即使为此增加加法运算次数,对于计算加速的效果也是非常明显的。

遵循这个「用加法换乘法」的基本思路,德国数学家Volken Strassen于1969年发现了更高效、占用计算资源更少的矩阵乘法算法。

实际上,这个思路在一些最基础的数学公式中就已经有充分体现。比如平方差公式:

a^2-b^2 =(a+b)*(a-b)

等号左侧计算两次乘法、一次加法,等号右侧计算一次乘法、两次加法。实际上,如果按照多项式乘法对等号右侧展开,实际上发生了正负ab的消去,将乘法运算的次数从4次降低为2次。

Strassen的算法是,利用原矩阵构造一些加乘结合的中间量,每个中间量只包含一次乘法计算,将原矩阵乘法转换为这些中间量的加法运算,将一些符号相反的乘法消去,实现降低乘法运算次数的目的。

在2*2矩阵的乘法中,Strassen的算法将乘法运算次数由8次降为7次。

矩阵乘法的张量表示和低秩分解

那么下一个问题就是,如何找到一种算法,构建能够消去乘法运算的中间量,同时更方便地利用强化学习技术?

DeepMind给出的答案是:将矩阵乘法转换为「低秩分解」问题。

同样以2*2矩阵为例,使用三维张量来表示 AB=C 的矩阵乘法运算过程,其中左右维度(列)为A,上下维度(行)为B,前后维度(深)为C。

用{0,1}对这个表示张量进行填充。C中取到值的部分,填充为1,其余填充为0。如下图所示。

 比如,c1=a1*b1+a2*b3,在「最深一层」所表示的c1上,可以看到左上方(第1行第1列)的a1b1,和第3行第2列的a2b3被表示为紫色1,其余为白色0。

在张量表示后,可以通过对矩阵的「低秩分解」,设张量Tn为两个 n×n 矩阵相乘的表示张量。将Tn分解为r个秩一项(rank-one term)的外积。

两个n维向量的外积可以得到一个n×n的矩阵,三个n维向量的外积可以得到一个 n×n×n 的张量。

仍以Strassen的算法为例,低秩分解后的结果,即上式中的U、V、W对应为3个7秩矩阵。这里的分解矩阵的秩决定原矩阵乘法中乘法运算的次数。

实际上,用这个方法可以将n×n矩阵乘法的计算复杂度降低至 O(Nlogn(R)) 。

由此可以设计一种规则,一一对应地得到图(b)中的矩阵乘法算法,即论文中的「算法1」:

 

建模:基于强化学习的AlphaTensor

DeepMind利用强化学习训练了一个AlphaTensor智能体来玩一个单人游戏(Tensor Game),开始时没有任何关于现有矩阵乘法算法的知识。

 

 这个强化学习模型正是基于此前的AI围棋大师AlphaZero。

那么这个游戏要如何设计,才能将其与矩阵乘法的简化建立联系,从而解决实际问题呢?

应用AlphaZero时,作者有一些特殊的网络架构技巧。

他们使用了线性代数的某些属性,比如,即使我们改变了线性运算的某些基础,问题也是同样的。因此,即使我们改变了矩阵的基础,它在本质上仍然代表同样的转换。

然而,对于这个算法来说,却不是这样的。

有了不同的数字,算法看起来就不同了,因为它是一种对彼此的转换。在这里,作者就很好地利用了线性代数的基本属性,创建出了更多的训练数据。

另外,分解3D张量看起来很难,但创造一个3D张量,就很容易。

 

 我们只需对添加的3个向量进行采样,把它们加在一起,就有了一个三维张量。经过正确的分解,它们还可以创建合成训练数据。

这些技巧都非常聪明,提供了更多的数据给系统。系统经过训练,可以准确地提供这些分解。

让我们分析一下神经网络架构,它是一个基于Transformer的网络。

本质上,它是一个强化学习算法。

首先要输入当前的张量以及张量的历史,接着是躯干(Torso),然后是嵌入(Embedding),最后是Policy Head和Value Head。

 在上图所指的位置,我们要选择三个向量u,v,w,进行相应计算。

一旦我们有三个向量的动作,我们就可以从原始张量中减去它。然后的目标是,找到从原始张量中减去的下一个动作。所有张量的Entry都是0的时候,游戏正好结束。 

 

这显然是一个离散问题。如果张量的阶数高于2,就属于NP hard。

这个任务实际上很艰巨,我们使用的是3个向量,每个向量都有对应的Entry,因此这是一个巨大的动作空间,比国际象棋或围棋之类的空间都大得多,因此也困难得多。

 

这是一个更精细的架构图。他们把最后一个时间步中出现的张量的历史,用各种方式把投影到这个网格层上,然后线性层Grid 2将其转换为某种C维向量(这里时间维度就减少了)。

 在这里,我们输出一个策略,这个策略是我们动作空间上的一个分布,还有一个输出到Value Head。 

 Value Head是从Policy Head中获取嵌入,然后通过一些神经网络推动。

要点就是,将网络与蒙特卡洛树搜索匹配。

总结一下:为了解决这些游戏,开始,我们的矩阵是满的,棋盘处于初始状态,然后就要考虑不同的动作,每一步动作都会包含更多的动作,包括你的对手可能考虑到的动作。

这其实就是一个树搜索算法。现在Alpha Zero style的蒙特卡洛树搜索,就是通过神经网络的策略和价值函数,引导我们完成这个树搜索。

它在用蓝线圈出的节点,就会向你提出建议,让你获得更成功的张量分解,也就是说,让你有更高的机率获胜。并且,它会直接排除掉你不该尝试的步骤,缩小你的考虑范围。

你只需要搜索,然后通过迭代训练,在某个节点,得到Zero Tensor,就意味着你胜利了。

没有完成游戏的话,奖励就非常低,反馈到训练神经网络之后,会做出更好的预测。

实际上,奖励不止是0或1, 为了鼓励模型发现最短路径,  作者还设定了一个-1的奖励。

这就比只给0或1的奖励好得多,因为它鼓励了低阶的分解,还提供了更密集的奖励信号。

因为问题很难,胜利具有很高的偶然性,奖励是稀少的。而如果走每一步都会得到奖励,也有可能是-1的奖励,就会敦促模型采取更少的步骤。

更重要的是,在这个合成演示中,他们会匹配一个监督奖励。

因为作者不仅可以生成数据,他们实际上是知道正确的步骤的,所以他们可以以监督的方式训练神经网络——因为是我们提出的问题,所以我们已经知道你该采取哪些步骤了。

再回顾一下整个算法。

 

针对原始游戏,作者改变了basis,将数据增强,然后进行蒙特卡洛树搜索。几个树搜索之后,游戏结束,根据结果的输赢,会得到相应的奖励,然后来训练。

把它放在游戏缓冲区,就可以更好地预测要执行的操作。

Policy Head会指导你走哪条路,在某个节点,你可以问Value Head:现在的状态值是多少?把所有内容汇总到顶部,选择最有希望的步骤。这就是MCTS Alpha Zero style的简介。

作者的另一个巧思是:除了-1的奖励,还在终端提供额外的奖励。如果算法在英伟达V100或TPUv2上运行得很快,还会得到额外的奖励。 

AlphaTensor当然不知道V100是什么,但通过强化学习的力量,我们就可以找到在特定硬件上速度非常快的算法。 

这样,我们就可以让算法提出定制的解决方案。

不仅是矩阵乘法,编译器也是这种原理。我们可以用这种方法,为特定的硬件优化速度、内存等。显然,它的应用领域已经远远超出了矩阵乘法。  whaosoft aiot http://143ai.com

对于数学的变革

作者还发现,对于两个四乘四矩阵相乘的得到的T4,AlphaTensor发现了超过14,000个非等价分解。

每种大小的矩阵乘法算法多达数千种,表明矩阵乘法算法的空间比以前想象的要丰富。

对于关心复杂性理论的数学家来说,这是一个巨大的发现。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_22237.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[GWCTF 2019]我有一个数据库

打开题目是乱码,好奇怪 御剑扫一下 扫到了phpmyadmin 版本为4.8.1 这个版本是有漏洞的(CVE-2018-12613),复现一下 部分源码: $target_blacklist array (import.php, export.php ); ​ // If we have a valid target…

SpringBoot统一处理返回格式

在对接第三方接口的时候,第三方接口返回格式形式为{"result":null,"status":1010},虽然返回了状态码,但是状态码对应的描述信息并没有携带,前端在使用的时候需要根据状态码返回一个友好的提示,如此…

刘韧:我每时每刻都会注意管理自己的知识

1. 担心总能让我积极行动起来。2. 要提早主动求变,不要等到被迫地、见招拆招地应变。3. 很多愚蠢的念头,都源于自己分内的事,却老想让别人负责,比如将自己的愿望寄托在子女身上。4. 推卸责任的同时,多少会对等地给予一…

ShardingSphere 5.2.0:分片审计功能拦截多分片场景下的不合理请求

一、背景Apache ShardingSphere 基于用户的实际使用场景,为用户打造了多种实用功能,包括数据分片、读写分离等。在数据分片功能中,我们发现有些用户涉及到的分片较多,一个分片逻辑表可能对应后端 1000 个物理表,这给用…

猿创征文 | 国产数据库实战之TiDB 数据库快速入门

猿创征文 | 国产数据库实战之TiDB 数据库快速入门一、系统检查1.检查系统版本2.查看本地IP地址3.TiDB集群介绍二、快速部署本地测试集群1.安装 TiUP工具2.声明全局环境变量3.快速部署TiDB 集群三、连接 TiDB 数据库1.新开一个session 以访问 TiDB 数据库2.通过Mysql客户端连接T…

SpringSecurity整合JWT+Oauth2认证

没写完&#xff0c;推荐下面的博客 推荐博客<查看pom依赖、数据库sql、实体类、数据映射>&#xff1a;SpringSecurity框架 推荐博客<查看SpringSecurity整合JWTOauth2认证>&#xff1a;SpringSecurity整合JWTOauth2认证 一 创建项目 测试浏览器&#xff1a;建议使用…

网课查题系统-公众号轻松调用方法

网课查题系统-公众号轻松调用方法 本平台优点&#xff1a; 多题库查题、独立后台、响应速度快、全网平台可查、功能最全&#xff01; 1.想要给自己的公众号获得查题接口&#xff0c;只需要两步&#xff01; 2.题库&#xff1a; 查题校园题库&#xff1a;查题校园题库后台&am…

Django ORM F对象和Q对象查询

Django ORM F对象和Q对象查询1.F对象查询2.Q对象查询Django提供了两个非常有用的工具&#xff1a;F对象和Q对象&#xff0c;方便了在一些特殊场景下的查询过程。 1.F对象查询 F对象用于操作数据库中某一列的值&#xff0c;它可以在没有实际访问数据库获取数据值的情况下对字段…

史上最简SLAM零基础解读(7) - Jacobian matrix(雅可比矩阵) → 理论分析与应用详解(Bundle Adjustment)

本人讲解关于slam一系列文章汇总链接:史上最全slam从零开始 文末正下方中心提供了本人联系方式&#xff0c;点击本人照片即可显示WX→官方认证{\color{blue}{文末正下方中心}提供了本人 \color{red} 联系方式&#xff0c;\color{blue}点击本人照片即可显示WX→官方认证}文末正…

基于微信小程序的毕业设计题目(23)php汽车维修保养小程序(含开题报告、任务书、中期报告、答辩PPT、论文模板)

项目背景和意义 目的&#xff1a;本课题主要目标是设计并能够实现一个基于微信汽车维修保养小程序系统&#xff0c;前台用户使用小程序&#xff0c;小程序使用微信开发者工具开发&#xff1b;后台管理使用基PPMySql的B/S架构&#xff0c;开发工具使用phpstorm&#xff1b;通过后…

毕业设计 单片机stm32智能路灯智能灯控系统 - LoRa远程通信

文章目录0 前言1 简介2 主要器件3 实现效果4 设计原理4.1 Lora模块4.2 DHT11温湿度传感器4.3 光照传感器5 部分核心代码6 最后0 前言 &#x1f525; 这两年开始毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的毕设题目缺少创新和亮点&#xff0c;往往达不到毕业答辩…

springboot使用布隆过滤器——缓存穿透

目录 1.布隆过滤器原理 2.具体使用场景 3.springboot集成布隆过滤器 4.总结 1.布隆过滤器原理 布隆过滤器&#xff08;Bloom Filter&#xff09;是非常经典的以空间换时间的算法。它实际上是一个很长的二进制向量和一系列随机映射函数。布隆过滤器可以用于检索一个元素是否…

虹科分享 | 什么是深度数据包检测(DPI)

深度数据包检测 (DPI) 是一种分析通过网络发送的流量的高级方法。DPI 使用数据处理来检查数据包的特定细节&#xff0c;作为数据包过滤的一种形式。 虽然 DPI 用于查看 OSI 模型的第 2-7 层&#xff0c;但仅当设备可以查看并根据第 3 层或更高层采取行动时&#xff0c;它才被视…

rsync+inotify实时同步

查看主页俩篇 inotify、rsync 编写脚本实现inotify与rsync相结合 客户端 #!/bin/bash Path/root/rsync_data backup_Server192.168.80.132 /usr/bin/inotifywait -mrq --format %w%f -e create,close_write,delete $Path | while read line do if [ -f $line ];then rsync -…

7个最佳WordPress设计师和摄影师作品插件

您是一名设计师或摄影师&#xff0c;正在寻找在 WordPress 中构建作品网站的最简单方法吗&#xff1f; 微信扫描二维码用手机阅读或收藏 有很多WordPress作品插件&#xff0c;可让您轻松构建漂亮的作品网站。 但是&#xff0c;对于初学者来说&#xff0c;找到完美的作品插件插…

从深圳寄东西到加拿大,用什么快递比较好?

哪家快递好这个是没有定论的&#xff0c;合适自己的渠道才是好渠道&#xff0c;通常情况下&#xff0c;四大快递和EMS这些基本都是没什么大问题的。下面方联国际物流就来带大家了解一下从深圳寄东西到加拿大的几种主要方式。目前有4种方式运输到加拿大&#xff1a;专线、快递、…

MaxCompute 笛卡尔积逻辑的参数优化复杂JOIN逻辑优化

1. 优化概述 最近协助一个项目做下优化任务的工作。因为主要数据都是报表&#xff0c;对数对的昏天暗地的不敢随便调整SQL逻辑&#xff0c;所以本身只想做点参数调整&#xff0c;但是逼不得已后来还是改了一下SQL。 这篇文章主要讲一个SQL优化反映的两个优化点。分别是&#…

AC 自动机算法介绍

一 点睛 AC 自动机是 KMP 算法和 Trie 树的结合&#xff0c;是经典的多模匹配算法。首先将多个模式串构建一棵字典树&#xff0c;然后在字典树上添加失配指针&#xff0c;失配指针相当于 KMP 算法中的 next 函数&#xff08;匹配失败时的回退位置&#xff09;&#xff0c;最后…

用python实现猜数字游戏

实现思路电脑随机生成1~100的整数,让用户去猜,用户每猜一次程序都会做出相应的提示。若用户输入所猜的数字小于电脑随机生成的数字,则提示“你猜小了”;若大于,则提示“你猜大了”;若等于,则提示“恭喜你赢了”(一直猜直到猜对游戏结束也可以控制猜的次数)这里需要用到p…

公众号如何搭建查题功能-拥有单独的后台

公众号如何搭建查题功能-拥有单独的后台 本平台优点&#xff1a; 多题库查题、独立后台、响应速度快、全网平台可查、功能最全&#xff01; 1.想要给自己的公众号获得查题接口&#xff0c;只需要两步&#xff01; 2.题库&#xff1a; 题库&#xff1a;题库后台&#xff08;点…