环境如下:
这是一个简单的环境,绿色方块代表终点,白色方块代表可行点,灰色方块代表陷阱
用Sarsa算法
和Q_learning算法
训练得到value表格
代码如下:
(jupyter notebook上的代码,所以顺序看起来有点儿奇怪)
def get_state(row,col):if row!=3:return 'ground'elif col==0:return 'ground'elif col==11:return 'terminal'else:return 'trap'
def envirment(row,col,action):if action==0:row-=1elif action==1:row+=1elif action==2:col-=1elif action==3:col+=1next_row=min(max(0,row),3)next_col=min(max(0,col),11)reward=-1if get_state(next_row,next_col)=='trap':reward=-100elif get_state(next_row,next_col)=='terminal':reward=100return next_row,next_col,reward
import numpy as np
import random
Q_pi=np.zeros([4,12,4])
def get_action(row,col):#获取下一步的动作if random.random()<0.1:return random.choice(range(4))#随机选一个动作else:return Q_pi[row,col].argmax()#选择Q_pi大的动作
def TD_sarsa(row,col,action,reward,next_row,next_col,next_action):TD_target=reward+0.9*Q_pi[next_row,next_col,next_action] #sarsa
# TD_target=reward+0.9*Q_pi[next_row,next_col].max()#Q_learnTD_error=Q_pi[row,col,action]-TD_targetreturn TD_errordef train():for epoch in range(3000):row = random.choice(range(4))col = 0action = get_action(row, col)reward_sum = 0# print(action)while get_state(row, col) not in ['terminal', 'trap']:next_row, next_col, reward = envirment(row, col, action)reward_sum += reward# print(row,col,next_row,next_col)next_action = get_action(next_row, next_col)TD_error = TD_sarsa(row, col, action, reward, next_row, next_col, next_action) # Q_learn时可以少传一个变量next_actioQ_pi[row, col, action] -= 0.1 * TD_error# print(row,col,next_row,next_col)row = next_rowcol = next_colaction = next_action# print("epoch")if epoch % 150 == 0:print(epoch, reward_sum)train()
#打印游戏,方便测试
def show(row, col, action):graph = ['□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□','□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□', '□','□', '□', '□', '□', '□', '□', '□', '□', '□', '○', '○', '○', '○', '○','○', '○', '○', '○', '○', '❤']action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]graph[row * 12 + col] = actiongraph = ''.join(graph)for i in range(0, 4 * 12, 12):print(graph[i:i + 12])from IPython import display
import timedef test():#起点row = random.choice(range(4))col = 0#最多玩N步for _ in range(200):#获取当前状态,如果状态是终点或者掉陷阱则终止if get_state(row, col) in ['trap', 'terminal']:break#选择最优动作action = Q_pi[row, col].argmax()#打印这个动作display.clear_output(wait=True)time.sleep(0.1)show(row, col, action)#执行动作row, col, reward = envirment(row, col, action)print(test())
#打印所有格子的动作倾向
for row in range(4):line = ''for col in range(12):action = Q_pi[row, col].argmax()action = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]line += actionprint(line)
结果:
value表格指示的action
测试结果如下:
需要注意的是sarsa算法
是跟一个策略函数π\piπ相关联的,应该是通过π\piπ来获取at和at+1a_t和a_{t+1}at和at+1的,但是这个代码里没有策略函数π\piπ,所以直接用value表格
来求at和at+1a_t和a_{t+1}at和at+1了,sarsa算法
通常是在 Actor-Critic
中担任’裁判’