基于DQN的强化学习 快速浏览(基础知识+示例代码)

news/2024/5/7 13:34:42/文章来源:https://blog.csdn.net/m0_37412775/article/details/128012232

一、强化学习的基础概念

强化学习中有2个主要的实体,一个是智能体(agent),另一个是环境(environment)。在强化学习过程中,智能体能够得到的是环境当前的状态(State),即环境智能体所处环境当前的情况。另一个是上一步获得的环境的奖励(Reward),即环境给予智能体动作的一个反馈。智能体根据这两个信息,决定在环境中采取的动作(Action),以及环境接收智能体的动作,返回下一步的状态和对智能体的奖励。整个过程可以归纳为:在t时刻,给定该时刻的状态s_t和获得的奖励r_t,根据这些值来决定当前步骤的动作a_t,将动作转递给环境,得到下一个时刻的状态s_{t+1}和获得的奖励r_{t+1}

在强化学习中需要解决的问题是如何训练一个智能体,使得智能体能够在合适的状态下产生合适的动作,使后续的奖励总和最大。我们称智能体根据环境状态产生动作的方法为策略(policy),在这种情况下,强化学习可以归结为寻找一个最优策略,使得未来能够获得的奖励最大。

强化学习有很多分类,其中根据深度学习模型描述的是策略本身还是在当前状态下未来能够获得的奖励,可分为两种,前者称为策略优化(policy optimization),后者称为质量函数学习(Q-learning,这里Q即为Quality,质量)。假设强化学习的过程是一个马尔可夫过程,即未来获得的奖励和过去的历史无关,则质量函数是当前状态的一个函数,可根据Bellman方程对其进行描述。假设奖励的折扣率为\gamma (0<\gamma<1),设置这个参数的目的是因为我们需要对未来的奖励进行求和,在时间跨度上无限大,如果没有折扣率,未来总的奖励可能是无穷大。在这种情况下,我们可以计算总的折扣后的奖励,如式(1)所示。

(1)

由于式(1)中的R_t,即未来奖励加权求和的结果和当前的状态及采取的动作有关,当引入t时刻的状态s_t和采取的动作a_t后,可以根据Bellman方程求得对应的折扣奖励关于状态和动作的函数,称之为质量函数Q,如式(2)。

(2)

二、强化学习的软件环境Gym安装

式(2)是进行强化学习的基础,用深度学习来学习质量函数的算法被称为DQN,这里以一个简单的强化学习例子来阐述如何使用DQN进行强化学习的任务。首先使用OpenAI的Gym工具集来构造强化学习环境。Gym工具集有很多场景,包括经典控制环境、雅达利游戏、二维和三维机器人环境等。

具体安装方式如下:

# 方式1:
pip install gym
# 方式2:
git clone https://github.com/openai/gym
pip install -e .

三、车杆环境介绍

本博文以gym的车杆环境为例,进行DQN的模型搭建和训练。

车杆(Cartpole)环境介绍如下:

车杆环境由一个可以自由转动的杆子连接一个可以水平运动的小车构成。通过向环境发送“左移”或者“右移”控制小车移动。每次发生移动指令之后,环境会返回一个数组来表示小车当前的运动状态,这个数组包括:小车当前的位置(-4.8~4.8)、小车当前的速度(负无穷到正无穷)、杆子当前的角度(-24~24)、杆子顶端的速度(负无穷到正无穷)来表示。环境还会返回一个奖励值,该值在杆子的角度在-15~15之间且小车位置在-2.4~2.4之间时为1(最多持续200步),其他状态下为0,且强化学习段落(episode)会在下一步终止。

我们需要训练的模型就是根据的状态数组做动态调整,让小车上的杆子能够保持在一定的角度范围内,且小车的位置也能保持在一定距离范围内,从而最终达到奖励值最大的目的。

 【注】:这里提到的强化学习段落episode指的是:All states that come in between an initial-state and a terminal-state; for example: one game of Chess. The agents goal it to maximize the total reward it reward it receives during an episode. In situations where there is no terminal-state, we consider an infinite episode. It is important to remember that different episodes are completely independent of one another. 大概意思是:一个episode即为一轮博弈,智能体从最开始的状态到某一个终止状态为一个episode,这个过程是有状态集、行为集、奖励等组成的一个完成序列。且各个episode是完全独立的。

四、车杆(Cartpole)环境使用

首先介绍如何使用Cartpole环境。代码及对应的解释如下:

import gym
env = gym.make('CartPole-v0') # 通过gym.make创建了一个CartPole环境env
for i_episode in range(20): # 共运行了20个强化学习段落observation = env.reset()for t in range(100): # 每个段落最大100步env.render()print(observation)action = env.action_space.sample() # 对动作进行随机采样(在CartPole环境下只有0和1两种动作)state, reward, terminated, truncated, info = env.step(action) # 执行动作,获取环境的反馈。state指执行了该动作后环境的状态,reward指当前动作获取的奖励,terminated表示当前段落是否结束;truncated通常指是否超出时间限制,info指环境的其他信息if terminated:print("Episode finished after {} timesteps".format(t+1))break
env.close()

这里再强调一遍gym中常用的代码接口及其含义:

代码接口含义
env = gym.make("XXX").env进入指定的实验环境
env.render()渲染环境,即可视化看看环境的样子
env.reset重置环境,返回一个随机的初始状态
env.step(action)将选择的action输入给env,env 按照这个动作走一步进入下一个状态
env.step(action)的返回值

state:执行action后的状态

reward:执行action的奖励

terminated:whether a terminal state (as defined under the MDP of the task) is reached. In this case further step() calls could return undefined results.

truncated:whether a truncation condition outside the scope of the MDP is satisfied. Typically a timelimit, but could also be used to indicate agent physically going out of bounds. Can be used to end the episode prematurely before a terminal state is reached.

info:其他信息

env.render()渲染出当前的智能体以及环境的状态,用于可视化

以上代码会让杆子很快偏离平衡位置,导致强化学习段落结束,为了能让杆子稳定,需要DQN对每一个env.step选择具体的动作。

五、QDN模型的搭建

先构建质量函数的深度学习模型,具体代码及解释如下:

import torch
import torch.nn as nn
class DQN(nn.Module):def __init__(self, naction, nstate, nhidden):super(DQN, self).__init__()self.naction = naction # 候选的动作总数。在CartPole中,该值为2self.nstate = nstate # 状态的维度数。在CartPole中,该值为4self.linear1 = nn.Linear(naction + nstate, nhidden)self.linear2 = nn.Linear(nhidden, nhidden)self.linear3 = nn.Linear(nhidden, 1)def forward(self, state, action):action_enc = torch.zeros(action.size(0), self.naction)action_enc.scatter_(1, action.unsqueeze(-1),1)output = torch.cat((state, action_enc), dim=-1)output = torch.relu(self.linear1(output))output = torch.relu(self.linear2(output))output = self.linear3(output)return output.squeeze(-1)

为了加强模型训练的收敛,在DQN算法的训练中需要用到重放(replay)技巧,通过反复播放强化学习的历史记录来加强模型的训练。这里用一个记忆类类记录训练历史,代码及解释如下:

import random
class Memory(object):def __init__(self,capacity=1000):self.capacity = capacity # 记忆的长短self.size = 0self.data = []def __len__(self):return self.sizedef push(self, state, action, state_next, reward, is_ended): # 向记忆类中放入单步训练的记录# state当前状态,action采取的动作,state_next下一步的状态,reward状态的奖励,is_ended下一步状态是否为最终状态if len(self) > self.capacity:k = random.randint(self.capacity)self.data.pop(k)self.size -= 1self.data.append((state, action, state_next, reward, is_ended))def sample(self, bs): # 获取一定迷你批次bs的历史数据进行重放,通过重放数据进行学习。data = random.choices(self.data, k = bs)states, actions, states_next, rewards, is_ended = zip(*data)states = torch.tensor(states, dtype=torch.float32)actions = torch.tensor(actions)states_next = torch.tensor(states_next, dtype=torch.float32)rewards =  torch.tensor(rewards, dtype=torch.float32)is_ended =  torch.tensor(is_ended, dtype=torch.float32)return states, actions, states_next, rewards, is_ended

六、DQN模型的训练

有了基础模型和重放类后,接下来对模型进行训练,代码及解释如下:

import copy
import torch.nn
# 定义2个网络,用于加速模型收敛
dqn = DQN(2,4,8) # 主要优化它
dqn_t = DQN(2,4,8) # 用来辅助dqn的模型优化,增强dqn的数值稳定性,加速模型收敛
dqn_t.load_state_dict(copy.deepcopy(dqn.state_dict()))
eps = 0.1 # 定义强化学习模型的探索比例。探索:对动作空间的随机采样达到遍历动作空间的目的,保证探索数量;利用:使用模型并选择模型的最优动作和环境进行交互,防止由重复探索出现。
# 折扣系数
gamma = 0.999
optim = torch.optim.Adam(dqn.parameters(), lr=1e-3)
criterion = torch.nn.HuberLoss() # 需要torch1.9.0及以上的版本
step_cnt= 0
mem = Memory()for episode in range(300):state = env.reset()while True:action_t = torch.tensor([0,1])state_t = torch.tensor([state, state], dtype=torch.float32)# 计算最优策略torch.set_grad_enabled(False)q_t = dqn(state_t, action_t)max_t = q_t.argmax()torch.set_grad_enabled(True)# 探索和利用的平衡if random.random() < eps:max_t = random.choice([0,1])else:max_t = max_t.iem()state_next, reward, done, truncated, info = env.step(max_t)mem.push(state, max_t, state_next, reward, done)state = state_nextif done:break# 重放训练for _ in range(10):state_t, action_t, state_next_t, reward_t, is_ended_t = mem.sample(32)q1 = dqn(state_t, action_t)torch.set_grad_enabled(False)q2_0 = dqn_t(state_next_t, torch.zeros(state_t.size(0), dtype=torch.long))q2_1 = dqn_t(state_next_t, torch.ones(state_t.size(0), dtype=torch.long))# 利用Bellman方程进行迭代q2_max = reward_t + gamma*(1-is_ended_t)*(torch.stack((q2_0, q2_1), dim=1).max(1)[0])torch.set_grad_enabled(True)# 优化损失函数delta = q2_max - q1loss = criterion(delta)optim.zero_grad()loss.backward()for p in dqn.parameters():p.grad.data.clamp_(-1,1)optim.step()step_cnt += 1# 同步2个网络的参数if step_cnt % 1000 == 0:dqn_t.load_state_dict(copy.deepcopy((dqn.state_dict())))env.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_224940.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端框架 Nuxtjs Vue3 SEO解决方案 SSR

目录 一、Nuxtjs安装 二、路由规则 三、公共布局 四、Vue3中TypeScript的使用 一、Nuxtjs安装 参考&#xff1a;Installation Get Started with Nuxt安装 - NuxtJS | Nuxt.js 中文网Installation Get Started with Nuxt yarn create nuxt-app <项目名> 项目运行…

WebDAV之葫芦儿·派盘+i简记

i简记 支持webdav方式连接葫芦儿派盘。 每天都去记录生活中所消费的琐碎开支,不仅浪费时间,还有很多广告和理财推销。那有没有纯粹的手机在线记账工具?可以轻松把微信、支付宝账单导入,支持外账入内,还有汇率转换等?答案是肯定的,i简记就是非常实用的在线记账工具。 i…

后端进阶知识 Buffer pool 图文详解 之 free链表

Buffer pool 图文详解 之 free 链表 前言数据页缓存页描述信息初始化 Buffer poolfree 链表获取空闲页数据页是否缓存可关注专栏 》MySQL 进阶知识 收藏点赞加关注 前言 Buffer pool 是 InnerDB 存储引擎的一个重要组件&#xff0c;MySQL 的所有 CRUD 操作都是围绕 Buffer poo…

Vue3+nodejs全栈项目(资金管理系统)——后端篇(一)登录、注册

文章目录初始化创建项目配置跨域配置解析表单数据的中间件安装bodyparser初始化用户路由模块抽离用户路由模块中的处理函数登录注册新建admin表安装并配置mysql模块注册检测表单数据是否合法检测用户名是否被占用对密码进行加密处理bcryptjs插入新用户测试登录根据名字查询用户…

Java基础知识+必考面试题(分享收藏版)

在学习Java语言之前&#xff0c;我们要了解相关知识体系&#xff0c;才能更好的掌握学习。那么下面我们就一起来学习JAVA语言吧~ Java语言概述 Java语言是Sun公司在1995年推出的高级编程语言&#xff0c;编程语言就是计算机语言&#xff0c;人们可以通过使用编程语言让计算机完…

红黑树C++实现

目录 一、红黑树的概念 二、红黑树的性质 三、红黑树节点的定义 四、红黑树的插入 4.1 插入节点 4.2 插入节点的颜色 4.3 调整情况1 4.4 调整情况2 4.5 调整情况3 4.6 调整情况总结 五、调整的实现 5.1 调整的步骤分析 5.2 代码实现 六、树的平衡判断 七、源代码…

一文让你了解数据采集

随着云计算、大数据、人工智能的发展&#xff0c;数据采集作为数据的重要手段&#xff0c;成为广大企业的迫切需求。 所谓“得数据者&#xff0c;得人工智能”&#xff0c;如今人工智能早已在我们的生活中屡见不鲜。如“人脸识别”、“语音唤醒音响”等都属于人工智能的范畴。…

擎创技术流 | ClickHouse实用工具—ckman教程(4)

《使用CKman导入集群》 CKman&#xff08;ClickHouse Manager&#xff09;是由擎创科技自主研发的一款管理ClickHouse的工具&#xff0c;前端用Vue框架&#xff0c;后端使用Go语言编写。它主要用来管理ClickHouse集群、节点以及数据监控等&#xff0c;致力于服务ClickHouse分布…

总结我的 MySQL 学习历程,给有需要的人看

作者| 慕课网精英讲师 马听 你好&#xff0c;我是马听&#xff0c;现在是某零售公司的 MySQL DBA&#xff0c;身处一线的我表示有很多话要讲。 我的MySQL学习历程 在我大三的时候&#xff0c;就开始接触到 MySQL 了&#xff0c;当时我也是从最基础的 MySQL 知识&#xff08;…

如何实现办公自动化?

办公自动化&#xff08;OA&#xff09;允许数据在没有人工干预的情况下流动。由于人工操作被排除在外&#xff0c;所以没有人为错误的风险。如今&#xff0c;办公自动化已经发展成无数的自动化和电子工具&#xff0c;改变了人们的工作方式。 办公自动化的好处 企业或多或少依…

[附源码]计算机毕业设计JAVA领导干部听课评课管理系统

[附源码]计算机毕业设计JAVA领导干部听课评课管理系统 项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM…

【LeetCode与《代码随想录》】哈希表篇:做题笔记与总结-JavaScript版

文章目录代码随想录主要题目242. 有效的字母异位词349. 两个数组的交集202. 快乐数1. 两数之和&#xff08;经典哈希&#xff09;454. 四数相加 II15. 三数之和&#xff08;双指针&#xff09;18. 四数之和&#xff08;双指针&#xff09;相关题目383. 赎金信49. 字母异位词分组…

Python遥感开发之GDAL读写遥感影像

Python遥感开发之GDAL读写遥感影像1 读取tif信息方法一2 读取tif信息方法二3 自己封装读取tif的方法&#xff08;推荐&#xff09;4 对读取的tif数据进行简单运算5 写出tif影像(推荐)前言&#xff1a;主要介绍了使用GDAL读写遥感影像数据的操作&#xff0c;包括读取行、列、投影…

从零学习 InfiniBand-network架构(八) —— IB协议中的原子操作

从零学习 InfiniBand-network架构&#xff08;八&#xff09; —— IB协议中的原子操作 &#x1f508;声明&#xff1a; &#x1f603;博主主页&#xff1a;王_嘻嘻的CSDN主页 &#x1f511;未经作者允许&#xff0c;禁止转载 &#x1f6a9;本专题部分内容源于《InfiniBand-net…

Docker——容器命令介绍、创建Nginx容器与Redis容器

目录 一、容器命令 二、创建并运行Nginx容器 1.1 去dockerhub查看Nginx容器运行命令 1.2 怎么访问Nginx&#xff1f; 1.3 查看容器日志 1.4总结 三、进入Nginx容器并修改HTML内容 3.1 进入容器 3.2 进入Nginx的HTML所在目录 3.3 修改index.html文件&#xff08;容器内修…

【OpenEVSE 】汽车充电桩控制项目解析

【OpenEVSE 】汽车充电桩控制项目解析1. 项目介绍2. 项目硬件3. 软件原理以及流程4. 系统结构&#xff1a;ESP32RAPI APIMQTT 上的 RAPI:5. SAE J1772协议简析&#xff1a;6. 专用充电接插件7 . 源码解析&#xff1a;此项目来源于openEnergyMonitor 的 openEVSE 部分&#xff0…

查阅必备----常用的SQL语句,配语句和图解超详细,不怕你忘记

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! 本文由 秩沅 原创 **收录于专栏 数据库 ⭐查阅必备–常用的SQL语句⭐ 文章目录⭐查阅必备--常用的SQL语句⭐一&#xff0c;关键语句大全&am…

python离线安装module以及常见问题及解决方案

文章目录一&#xff0c;离线安装module1.1 下载module1.2 离线安装二&#xff0c;常见的问题2.1 模块缺少合适的适配&#xff1a;error: Could not find suitable distribution for Requirement.parse()2.2 install成功但发现控制台打印的最后一行显示下载module版本为0.0.0工作…

微信商城小程序怎么开发_分享微信商城小程序的搭建

如何搭建好一个微信商城&#xff1f;这三个功能要会用&#xff01; 1.定期低价秒杀&#xff0c;提高商城流量 除了通过私域流量裂变&#xff0c;低价秒杀是为商城引流提高打开率的良好手段。 以不同节日作为嘘头&#xff0c;在情人节、38妇女节、中秋国庆、七夕节等日子&…

机器学习-回归模型相关重要知识点

目录01 线性回归的假设是什么&#xff1f;02 什么是残差&#xff0c;它如何用于评估回归模型&#xff1f;03 如何区分线性回归模型和非线性回归模型&#xff1f;04 什么是多重共线性&#xff0c;它如何影响模型性能&#xff1f;05 异常值如何影响线性回归模型的性能&#xff1f…