基本概念¶
天授把一个RL训练流程划分成了几个子模块:trainer(负责训练逻辑)、collector(负责数据采集)、policy(负责训练策略)和 buffer(负责数据存储),此外还有两个外围的模块,一个是env,一个是model(policy负责RL算法实现比如loss function的计算,model就只是个正常的神经网络)。下图描述了这些模块的依赖:
Batch¶
天授提供了 Batch
作为内部模块传递数据所使用的数据结构,它既像字典又像数组,可以以这两种方式组织数据和访问数据,像下面这样:
>>> import torch, numpy as np
>>> from tianshou.data import Batch
>>> data = Batch(a=4, b=[5, 5], c='2312312', d=('a', -2, -3))
>>> # 注意,list会自动变成numpy
>>> data.b
array([5, 5])
>>> data.b = np.array([3, 4, 5])
>>> print(data)
Batch(
a: 4,
b: array([3, 4, 5]),
c: '2312312',
d: array(['a', '-2', '-3'], dtype=object),
)
>>> data = Batch(obs={'index': np.zeros((2, 3))}, act=torch.zeros((2, 2)))
>>> data[:, 1] += 6
>>> print(data[-1])
Batch(
obs: Batch(
index: array([0., 6., 0.]),
),
act: tensor([0., 6.]),
)
总之就是可以定义任何key-value放在Batch里面,然后可以做一些常规的操作比如+-*、cat/stack之类的。Understand Batch 里面详细描述了Batch的各种用法,非常值得一看(虽然它是英文的但只要看代码也还行?)。
Buffer¶
ReplayBuffer
负责存储数据和采样出来数据用于policy的训练。目前天授保留了7个关键字在Buffer里面:
obs
时刻的观测值;act
时刻采取的动作值;rew
时刻环境返回的奖励函数值;done
时刻是否结束这个episode;obs_next
时刻的观测值;info
时刻环境给出的额外信息(gym.Env会返回4个东西,最后一个就是它);policy
时刻由policy计算出的需要额外存储的数据;
下面的代码片段展示了Buffer的一些典型用法:
>>> import pickle, numpy as np
>>> from tianshou.data import Batch, ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
... buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))
>>> buf.obs # 因为设置了 size = 20,所以 len(buf.obs) == 20
array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
>>> # 但是里面只有3个合法的数据,因此 len(buf) == 3
>>> len(buf)
3
>>> pickle.dump(buf, open('buf.pkl', 'wb')) # 把buffer所有数据保存到 "buf.pkl"
>>> buf.save_hdf5('buf.hdf5') # 把buffer所有数据保存到 "buf.hdf5"
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... done = i % 4 == 0
... buf2.add(Batch(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={}))
>>> len(buf2)
10
>>> buf2.obs # 因为 buf2 的 size = 10,所以它只会存储最后10步的结果
array([10, 11, 12, 13, 14, 5, 6, 7, 8, 9])
>>> buf.update(buf2) # 把 buf2 的数据挪到buf里面,同时保持相对时间顺序
>>> buf.obs
array([ 0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0,
0, 0, 0, 0])
>>> indice = buf.sample_index(0) # 使用 batchsize=0 来获取buffer里面的全部数据
>>> indice
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
>>> buf.prev(indice) # 给定index,计算上一个transition所对应的index
array([ 0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11])
>>> buf.next(indice) # 给定index,计算下一个transition所对应的index
array([ 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12])
>>> # 从buffer里面拿一个随机的数据,batch_data就是buf[indice]
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
>>> len(buf)
13
>>> buf = pickle.load(open('buf.pkl', 'rb')) # 从"buf.pkl"文件恢复出buffer
>>> len(buf)
3
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5') # 从"buf.hdf5"导入完整的buffer
>>> len(buf)
3
ReplayBuffer
还支持堆叠采样(为了RNN,详情查看 Issue 19)、不存储obs_next(为了省些内存),以及任意类型的数据结构存储(这个是Batch支持的):
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
>>> for i in range(16):
... done = i % 5 == 0
... ptr, ep_rew, ep_len, ep_idx = buf.add(
... Batch(obs={'id': i}, act=i, rew=i,
... done=done, obs_next={'id': i + 1}))
... print(i, ep_len, ep_rew)
0 [1] [0.]
1 [0] [0.]
2 [0] [0.]
3 [0] [0.]
4 [0] [0.]
5 [5] [15.]
6 [0] [0.]
7 [0] [0.]
8 [0] [0.]
9 [0] [0.]
10 [5] [40.]
11 [0] [0.]
12 [0] [0.]
13 [0] [0.]
14 [0] [0.]
15 [5] [65.]
>>> print(buf) # 可以发现obs_next并不在里面存着
ReplayBuffer(
obs: Batch(
id: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
),
act: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
done: array([False, True, False, False, False, False, True, False,
False]),
)
>>> index = np.arange(len(buf))
>>> print(buf.get(index, 'obs').id)
[[ 7 7 8 9]
[ 7 8 9 10]
[11 11 11 11]
[11 11 11 12]
[11 11 12 13]
[11 12 13 14]
[12 13 14 15]
[ 7 7 7 7]
[ 7 7 7 8]]
>>> # 也可以这样取出stacked过的obs(注意stack只对obs/obs_next/info/policy有效)
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
0
>>> # 可以通过 __getitem__ 来弄出obs_next(虽然并没存),但是[:]会按照时间顺序(而不是实际存储顺序)来取数据
>>> # 比如下面这个就相当于 index == [7, 8, 0, 1, 2, 3, 4, 5, 6]
>>> print(buf[:].obs_next.id)
[[ 7 7 7 8]
[ 7 7 8 9]
[ 7 8 9 10]
[ 7 8 9 10]
[11 11 11 12]
[11 11 12 13]
[11 12 13 14]
[12 13 14 15]
[12 13 14 15]]
>>> full_index = np.array([7, 8, 0, 1, 2, 3, 4, 5, 6])
>>> np.allclose(buf[:].obs_next.id, buf[full_index].obs_next.id)
True
天授还提供了其他类型的buffer比如 PrioritizedReplayBuffer
(基于线段树)、VectorReplayBuffer
(能够向其中添加不同episode的数据的同时维护时间顺序)。可以访问对应的文档来查看。
Policy¶
天授把一个RL算法用一个继承自 BasePolicy
的类来实现,主要的部分有如下几个:
__init__()
:策略初始化,比如初始化自定义的模型等;forward()
:根据给定的观测值obs,计算出动作值action;process_fn()
:在获取训练数据之前和buffer进行交互,比如使用GAE或者nstep算法来估计优势函数;learn()
:使用一个Batch的数据进行策略的更新;post_process_fn()
:使用一个Batch的数据进行Buffer的更新(比如更新PER);update()
:最主要的接口。这个update函数先是从buffer采样出一个batch,然后调用process_fn预处理,然后learn更新策略,然后 post_process_fn完成一次迭代:process_fn -> learn -> post_process_fn
。
各种状态和阶段¶
强化学习训练流程可以分为两个部分:训练部分(Training state)和测试部分(Testing State),而训练部分可以细分为采集数据阶段(Collecting state)和更新策略阶段(Updating state),两个阶段在训练过程中交替进行。 顾名思义,采集数据阶段是由collector负责的,而策略更新阶段是由policy.update负责的。
为了区分上述这些状态,可以通过检查 policy.training
和 policy.updating
来确定处于哪个状态,这边列了一张表方便查看:
State for policy |
policy.training |
policy.updating |
|
---|---|---|---|
Training state |
Collecting state |
True |
False |
Updating state |
True |
True |
|
Testing state |
False |
False |
policy.updating
实际情况下主要用于exploration,比如在各种Q-Learning算法中,在不同的policy state切换探索策略。
policy.forward¶
forward
函数接收obs计算action,输入和输出由于算法的不同而不同,但大部分情况下是这样的:(batch, state, ...) -> batch
。
输入的Batch是环境中给出的数据(observation、reward、done 和 info),要么来自 tianshou.data.Collector.collect
(Collecting state),要么来自``tianshou.data.ReplayBuffer.sample``(Updating state)。Batch里面的所有数据第一维都是batch-size。
输出也是一个Batch,必须包含 act
关键字,可能包含 state
关键字(用于存放hiddle state,RNN使用)、policy
关键字(policy计算过程中需要存储到buffer里面的中间结果,比如logprob之类的,后续更新网络需要用到),以及其他key(只不过不会被存储到buffer里面)。
比如您想要使用policy单独来evaluate一个episode,不用collect给出的函数,可以像下面这样做:
# env 是 gym.Env
obs, done = env.reset(), False
while not done:
batch = Batch(obs=[obs]) # 第一维是 batch size
act = policy(batch).act[0] # policy.forward 返回一个 batch,使用 ".act" 来取出里面action的数据
obs, rew, done, info = env.step(act)
这边 Batch(obs=[obs])
会自动为obs下面的所有数据创建第0维,让它为batch size=1,否则神经网络没法确定batch size。
policy.process_fn¶
process_fn
用于计算时间相关的序列信息,比如计算n-step returns或者GAE returns。这边拿2-step DQN举例,公式是
是 discount factor,。下面给出了未使用天授的训练过程伪代码:
s = env.reset()
buffer = Buffer(size=10000)
agent = DQN()
for i in range(int(1e6)):
a = agent.compute_action(s)
s_, r, d, _ = env.step(a)
buffer.store(s, a, s_, r, d)
s = s_
if i % 1000 == 0:
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)
# 计算 2-step returns,咋算呢?
b_ret = compute_2_step_return(buffer, b_r, b_d, ...)
# 更新 DQN policy
agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret)
从上面伪代码可以看出我们需要一个依赖于时间相关的接口来计算2-step returns。process_fn()
就是用来做这件事的,给它一个replay buffer、采样用的index(相当于时间t)和采样出来的batch就能计算。因为在buffer里面我们按照时间顺序存储各种数据,因此2-step returns的计算可以像下面这样简单:
class DQN_2step(BasePolicy):
"""其他的代码"""
def process_fn(self, batch, buffer, indice):
buffer_len = len(buffer)
batch_2 = buffer[(indice + 2) % buffer_len]
# 上面这个代码访问batch_2.obs就是s_{t+2},也可以像下面这样访问:
# batch_2_obs = buffer.obs[(indice + 2) % buffer_len]
# 总之就是 buffer.obs[i] 和 buffer[i].obs是一个意思,但是前面的这种写法效率更高
Q = self(batch_2, eps=0) # shape: (batch_size, action_shape)
maxQ = Q.max(dim=-1)
batch.returns = batch.rew \
+ self._gamma * buffer.rew[(indice + 1) % buffer_len] \
+ self._gamma ** 2 * maxQ
return batch
上面这个代码并没考虑 done = True
的情况,因此正确性不能保证,但是它展示了两种能够访问到 的方法。
至于policy的其他功能,可以参考 tianshou.policy,在最下面给出了一个宏观解释:宏观解释。
Collector¶
Collector
负责policy与env的交互和数据存储。collect()
是collector的主要方法,它能够指定让policy和环境交互给定数目 n_step
个step或者 n_episode
个episode,并把该过程中产生的数据存储到buffer中。
宏观解释 给出了一个宏观层面的解释,其他collector的功能可参考对应文档。此处列出一些常用用法:
policy = PGPolicy(...) # 或者其他policy都可以
env = gym.make("CartPole-v0")
replay_buffer = ReplayBuffer(size=10000)
# 这里单个env对应ReplayBuffer
collector = Collector(policy, env, buffer=replay_buffer)
# 多个env的话得用VectorReplayBuffer,但是collector仍然适用
vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num=3)
# buffer_num推荐和env数量相等
envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(3)])
collector = Collector(policy, envs, buffer=vec_buffer)
# 收集3个episode
collector.collect(n_episode=3)
# 收集至少俩step(这个会收集三个,因为有三个env,每次收集的次数得是3的倍数)
collector.collect(n_step=2)
# 边收集变直播,使用render参数就可以(render传入的是时间间隔,以秒为单位)
collector.collect(n_episode=1, render=0.03)
还有个:class:~tianshou.data.AsyncCollector,继承了:class:~tianshou.data.Collector,它支持异步的环境采样(比如环境很慢或者step时间差异很大)。不过AsyncCollector的collect的语义和上面Collector有所不同,由于异步的特性,它只能保证**至少** n_step
或者 n_episode
地收集数据。
Trainer¶
有了之前声明的collector和policy之后,就可以用trainer把它们包起来。Trainer负责最上层训练逻辑的控制,例如训练多少次之后进行策略和环境的交互。现有的训练器包括同策略学习训练器(On-policy Trainer)和异策略学习训练器(Off-policy Trainer)。
天授未显式地将训练器抽象成一个类,因为在其他现有平台中都将类似训练器的实现抽象封装成一个类,导致用户难以二次开发。因此以函数的方式实现训练器,并提供了示例代码便于研究者进行定制化训练策略的开发。可以参考 定制化训练器。
宏观解释¶
接下来将通过一段伪代码的讲解来阐释上述所有抽象模块的应用。
# pseudocode, cannot work # 对应天授实现
s = env.reset() # 环境初始化,在env中实现
buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000)
agent = DQN() # policy.__init__(...)
for i in range(int(1e6)): # 在Trainer中实现
a = agent.compute_action(s) # act = policy(batch, ...).act
s_, r, d, _ = env.step(a) # collector.collect(...)
buffer.store(s, a, s_, r, d) # collector.collect(...)
s = s_ # collector.collect(...)
if i % 1000 == 0: # 在Trainer中实现
# the following is done in policy.update(batch_size, buffer)
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # batch, indice = buffer.sample(batch_size)
# 计算 2-step returns,咋算呢?
b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # policy.process_fn(batch, buffer, indice)
# 更新 DQN policy
agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret) # policy.learn(batch, ...)