Copyright 2023 The TF-Agents Authors.
Introduction
This example shows how to train a REINFORCE agent on the Cartpole environment using the TF-Agents library, similar to the DQN tutorial.
We will walk you through all the components of a Reinforcement Learning (RL) pipeline for training, evaluation and data collection.
Setup
If you haven't installed the following dependencies, run:
sudo apt-get update
sudo apt-get install -y xvfb ffmpeg freeglut3-dev
pip install 'imageio==2.4.0'
pip install pyvirtualdisplay
pip install tf-agents[reverb]
pip install pyglet xvfbwrapper
pip install tf-keras
import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import imageio
import IPython
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay
import reverb
import tensorflow as tf
from tf_agents.agents.reinforce import reinforce_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import py_tf_eager_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
Hyperparameters
env_name = "CartPole-v0" # @param {type:"string"}
num_iterations = 250 # @param {type:"integer"}
collect_episodes_per_iteration = 2 # @param {type:"integer"}
replay_buffer_capacity = 2000 # @param {type:"integer"}
fc_layer_params = (100,)
learning_rate = 1e-3 # @param {type:"number"}
log_interval = 25 # @param {type:"integer"}
num_eval_episodes = 10 # @param {type:"integer"}
eval_interval = 50 # @param {type:"integer"}
Environment
Environments in RL represent the task or problem we are trying to solve. Standard environments can be created easily in TF-Agents using suites. We have different suites for loading environments from sources such as OpenAI Gym, Atari, and DM Control, given a string environment name.
Now let us load the CartPole environment from the OpenAI Gym suite.
env = suite_gym.load(env_name)
We can render this environment to see how it looks. A free-swinging pole is attached to a cart. The goal is to move the cart right or left in order to keep the pole pointing up.
env.reset()
PIL.Image.fromarray(env.render())
The time_step = environment.step(action) statement takes an action in the environment. The TimeStep tuple it returns contains the environment's next observation and the reward for that action. The time_step_spec() and action_spec() methods of the environment return the specifications (types, shapes, bounds) of the time_step and the action, respectively.
print('Observation Spec:')
print(env.time_step_spec().observation)
print('Action Spec:')
print(env.action_spec())
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)
So we see that the observation is an array of 4 floats: the position and velocity of the cart, and the angular position and velocity of the pole. Since only two actions are possible (move left or move right), the action_spec is a scalar where 0 means "move left" and 1 means "move right".
time_step = env.reset()
print('Time step:')
print(time_step)
action = np.array(1, dtype=np.int32)
next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step:
TimeStep(
{'step_type': array(0, dtype=int32),
 'reward': array(0., dtype=float32),
 'discount': array(1., dtype=float32),
 'observation': array([-0.00907558,  0.02627698, -0.01019297,  0.04808202], dtype=float32)})
Next time step:
TimeStep(
{'step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32),
 'observation': array([-0.00855004,  0.2215436 , -0.00923133, -0.24779937], dtype=float32)})
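As a small aside (this snippet is not part of the original notebook), instead of hard-coding the action we can sample a valid one directly from the bounded action spec; it only assumes numpy and the env created above:

# Sample a random valid action (0 or 1) from the action spec and step with it.
action_spec = env.action_spec()
random_action = np.array(
    np.random.randint(action_spec.minimum, action_spec.maximum + 1),
    dtype=action_spec.dtype)
time_step = env.step(random_action)
print('Reward after random action:', time_step.reward)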
Usually we create two environments: one for training and one for evaluation. Most environments are written in pure Python, but they can easily be converted to TensorFlow using the TFPyEnvironment wrapper. The original environment's API uses numpy arrays; the TFPyEnvironment converts these to and from Tensors so that you can more easily interact with TensorFlow policies and agents.
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
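To see what the wrapper changes (an optional check rather than part of the original notebook), note that the wrapped environments now return batched Tensors instead of numpy arrays:

# The TFPyEnvironment adds a leading batch dimension of size 1 and returns tf.Tensors.
batched_time_step = train_env.reset()
print(batched_time_step.observation.shape)  # (1, 4)
print(batched_time_step.reward)             # a tf.Tensor, not a numpy array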
Agent
The algorithm we use to solve an RL problem is represented as an Agent. In addition to the REINFORCE agent, TF-Agents provides standard implementations of a variety of Agents such as DQN, DDPG, TD3, PPO and SAC.
To create a REINFORCE agent, we first need an Actor Network that can learn to predict the action given an observation from the environment.
We can easily create an Actor Network using the specs of the observations and actions. We can specify the layers in the network which, in this example, is the fc_layer_params argument set to a tuple of ints representing the sizes of each hidden layer (see the Hyperparameters section above).
actor_net = actor_distribution_network.ActorDistributionNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)
We also need an optimizer to train the network we just created, and a train_step_counter variable to keep track of how many times the network was updated.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)
tf_agent = reinforce_agent.ReinforceAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=actor_net,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter)
tf_agent.initialize()
Policies
In TF-Agents, policies represent the standard notion of policies in RL: given a time_step, produce an action or a distribution over actions. The main method is policy_step = policy.action(time_step), where policy_step is a named tuple PolicyStep(action, state, info). policy_step.action is the action to be applied to the environment, state represents the state for stateful (RNN) policies, and info may contain auxiliary information such as log probabilities of the actions.
Agents contain two policies: the main policy that is used for evaluation/deployment (agent.policy) and another policy that is used for data collection (agent.collect_policy).
eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy
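As a minimal usage sketch (not in the original notebook), calling a policy on a batched time step from the TF environment returns the PolicyStep named tuple described above:

# Query the (still untrained) evaluation policy for one action.
example_time_step = eval_env.reset()
policy_step = eval_policy.action(example_time_step)
print(policy_step.action)  # a batched integer Tensor containing 0 or 1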
Metrics and Evaluation
The most common metric used to evaluate a policy is the average return. The return is the sum of rewards obtained while running a policy in an environment for an episode, and we usually average this over a few episodes. We can compute the average return metric as follows.
def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


# Please also see the metrics module for standard implementations of different
# metrics.
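For example (an optional sanity check, not part of the original notebook), we can evaluate the untrained agent's policy for a few episodes; expect a low average return since the policy has not been trained yet:

# Average return of the untrained policy over 5 evaluation episodes.
print(compute_avg_return(eval_env, eval_policy, num_episodes=5))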
Replay Buffer
In order to keep track of the data collected from the environment, we will use Reverb, an efficient, extensible, and easy-to-use replay system by Deepmind. It stores experience data when we collect trajectories, and it is consumed during training.
This replay buffer is constructed using specs describing the tensors that are to be stored, which can be obtained from the agent using tf_agent.collect_data_spec.
table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(
    tf_agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(
    replay_buffer_signature)

table = reverb.Table(
    table_name,
    max_size=replay_buffer_capacity,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=replay_buffer_signature)

reverb_server = reverb.Server([table])

replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    table_name=table_name,
    sequence_length=None,
    local_server=reverb_server)

rb_observer = reverb_utils.ReverbAddEpisodeObserver(
    replay_buffer.py_client,
    table_name,
    replay_buffer_capacity
)
For most agents, collect_data_spec is a Trajectory named tuple containing the observations, actions, rewards, etc.
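As a quick check (this snippet is not part of the original notebook), we can inspect the fields of that named tuple directly; the exact contents of policy_info depend on the agent:

print(tf_agent.collect_data_spec)
print(tf_agent.collect_data_spec._fields)
# Expected fields of a Trajectory:
# ('step_type', 'observation', 'action', 'policy_info', 'next_step_type', 'reward', 'discount')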
Data Collection
As REINFORCE learns from whole episodes, we define a function to collect an episode using the given data collection policy and save the data (observations, actions, rewards, etc.) as trajectories in the replay buffer. Here we are using 'PyDriver' to run the experience collecting loop. You can learn more about TF Agents drivers in our drivers tutorial.
def collect_episode(environment, policy, num_episodes):

  driver = py_driver.PyDriver(
    environment,
    py_tf_eager_policy.PyTFEagerPolicy(
      policy, use_tf_function=True),
    [rb_observer],
    max_episodes=num_episodes)
  initial_time_step = environment.reset()
  driver.run(initial_time_step)
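Before training, we can sanity-check the collection function by gathering a single episode and looking at the buffer size (an optional check that is not part of the original notebook; it assumes the replay buffer's num_frames() helper):

# Collect one episode, inspect the buffer, then clear it so training starts from an empty buffer.
collect_episode(train_py_env, tf_agent.collect_policy, num_episodes=1)
print('Replay buffer size after one episode:', replay_buffer.num_frames())
replay_buffer.clear()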
Training the agent
The training loop involves both collecting data from the environment and optimizing the agent's networks. Along the way, we will occasionally evaluate the agent's policy to see how we are doing.
The following will take about 3 minutes to run.
try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
tf_agent.train = common.function(tf_agent.train)

# Reset the train step
tf_agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few episodes using collect_policy and save to the replay buffer.
  collect_episode(
      train_py_env, tf_agent.collect_policy, collect_episodes_per_iteration)

  # Use data from the buffer and update the agent's network.
  iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
  trajectories, _ = next(iterator)
  train_loss = tf_agent.train(experience=trajectories)

  replay_buffer.clear()

  step = tf_agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss.loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
step = 25: loss = 1.8318419456481934
step = 50: loss = 0.0070743560791015625
step = 50: Average Return = 9.800000190734863
step = 75: loss = 1.1006038188934326
step = 100: loss = 0.5719594955444336
step = 100: Average Return = 50.29999923706055
step = 125: loss = -1.2458715438842773
step = 150: loss = 1.9363441467285156
step = 150: Average Return = 98.30000305175781
step = 175: loss = 0.8784818649291992
step = 200: loss = 1.9726766347885132
step = 200: Average Return = 143.6999969482422
step = 225: loss = 2.316105842590332
step = 250: loss = 2.5175299644470215
step = 250: Average Return = 191.5
Visualization
Plots
We can plot return vs. global steps to see the performance of our agent. In CartPole-v0, the environment gives a reward of +1 for every time step the pole stays up, and since the maximum number of steps is 200, the maximum possible return is also 200.
steps = range(0, num_iterations + 1, eval_interval)
plt.plot(steps, returns)
plt.ylabel('Average Return')
plt.xlabel('Step')
plt.ylim(top=250)
(0.7150002002716054, 250.0)
Videos
It is helpful to visualize the performance of an agent by rendering the environment at each step. Before we do that, let us first create a function to embed videos in this colab.
def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)
The following code visualizes the agent's policy for a few episodes:
num_episodes = 3
video_filename = 'imageio.mp4'
with imageio.get_writer(video_filename, fps=60) as video:
  for _ in range(num_episodes):
    time_step = eval_env.reset()
    video.append_data(eval_py_env.render())
    while not time_step.is_last():
      action_step = tf_agent.policy.action(time_step)
      time_step = eval_env.step(action_step.action)
      video.append_data(eval_py_env.render())
embed_mp4(video_filename)