Drivers

Introduction

A common pattern in reinforcement learning is to execute a policy in an environment for a specified number of steps or episodes. This happens, for example, during data collection, evaluation, and when generating a video of the agent.

While this is relatively simple to write in Python, it is much more complex to write and debug in TensorFlow because it involves tf.while loops, tf.cond and tf.control_dependencies. Therefore we abstract this notion of a run loop into a class called driver, and provide well-tested implementations in both Python and TensorFlow.

Additionally, the data encountered by the driver at each step is saved in a named tuple called Trajectory and broadcast to a set of observers such as replay buffers and metrics. This data includes the observation from the environment, the action recommended by the policy, the reward obtained, the types of the current and the next step, etc.
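As a quick illustration (a minimal sketch with hypothetical names, not part of the original notebook), an observer is simply any callable that accepts a Trajectory: a Python list's append method, a metric object, or a plain function all work.

import numpy as np

# Hypothetical observer: prints the reward carried by each trajectory.
# Trajectory fields include step_type, observation, action, policy_info,
# next_step_type, reward and discount.
def print_reward(traj):
  print('reward:', traj.reward)

# Hypothetical observer: counts finished episodes as trajectories stream by.
class EpisodeCounter(object):

  def __init__(self):
    self.count = 0

  def __call__(self, traj):
    # is_last() is True on the last step of an episode.
    self.count += np.sum(traj.is_last())

# Observers are passed to a driver as a list, e.g.:
#   py_driver.PyDriver(env, policy, [print_reward, EpisodeCounter()], max_steps=10)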

Setup

If you haven't installed tf-agents or gym yet, run:

pip install tf-agents
pip install tf-keras
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import tensorflow as tf


from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_py_policy
from tf_agents.policies import random_tf_policy
from tf_agents.metrics import py_metrics
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import py_driver
from tf_agents.drivers import dynamic_episode_driver

Python Drivers

The PyDriver class takes a python environment, a python policy and a list of observers to update at each step. The main method is run(), which steps the environment using actions from the policy until at least one of the following termination criteria is met: the number of steps reaches max_steps or the number of episodes reaches max_episodes.

The implementation is roughly as follows:

import numpy as np

from tf_agents.trajectories import trajectory


class PyDriver(object):

  def __init__(self, env, policy, observers, max_steps=1, max_episodes=1):
    self._env = env
    self._policy = policy
    self._observers = observers or []
    self._max_steps = max_steps or np.inf
    self._max_episodes = max_episodes or np.inf

  def run(self, time_step, policy_state=()):
    num_steps = 0
    num_episodes = 0
    while num_steps < self._max_steps and num_episodes < self._max_episodes:

      # Compute an action using the policy for the given time_step
      action_step = self._policy.action(time_step, policy_state)

      # Apply the action to the environment and get the next step
      next_time_step = self._env.step(action_step.action)

      # Package information into a trajectory
      traj = trajectory.Trajectory(
         time_step.step_type,
         time_step.observation,
         action_step.action,
         action_step.info,
         next_time_step.step_type,
         next_time_step.reward,
         next_time_step.discount)

      for observer in self._observers:
        observer(traj)

      # Update statistics to check termination
      num_episodes += np.sum(traj.is_last())
      num_steps += np.sum(~traj.is_boundary())

      time_step = next_time_step
      policy_state = action_step.state

    return time_step, policy_state

Now let's run an example of a random policy on the CartPole environment, saving the results to a replay buffer and computing some metrics.

env = suite_gym.load('CartPole-v0')
policy = random_py_policy.RandomPyPolicy(time_step_spec=env.time_step_spec(), 
                                         action_spec=env.action_spec())
replay_buffer = []
metric = py_metrics.AverageReturnMetric()
observers = [replay_buffer.append, metric]
driver = py_driver.PyDriver(
    env, policy, observers, max_steps=20, max_episodes=1)

initial_time_step = env.reset()
final_time_step, _ = driver.run(initial_time_step)

print('Replay Buffer:')
for traj in replay_buffer:
  print(traj)

print('Average Return: ', metric.result())
Replay Buffer:
Trajectory(
{'step_type': array(0, dtype=int32),
 'observation': array([ 0.00374074, -0.02818722, -0.02798625, -0.0196638 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00317699,  0.16732468, -0.02837953, -0.3210437 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00652349, -0.02738187, -0.0348004 , -0.03744393], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00597585, -0.22198795, -0.03554928,  0.24405919], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00153609, -0.41658458, -0.0306681 ,  0.5253204 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.0067956 , -0.61126184, -0.02016169,  0.80818397], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.01902084, -0.8061018 , -0.00399801,  1.0944574 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.03514287, -0.6109274 ,  0.01789114,  0.8005227 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.04736142, -0.8062901 ,  0.03390159,  1.0987796 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.06348722, -0.61163044,  0.05587719,  0.816923  ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.07571983, -0.41731614,  0.07221565,  0.54232585], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.08406615, -0.61337477,  0.08306216,  0.8568603 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.09633365, -0.8095243 ,  0.10019937,  1.1744623 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.11252414, -0.6158369 ,  0.12368862,  0.91479784], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.12484087, -0.8123951 ,  0.14198457,  1.2436544 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.14108877, -0.61935145,  0.16685766,  0.9986062 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.1534758 , -0.42680538,  0.18682979,  0.7626272 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.1620119 , -0.23468053,  0.20208232,  0.5340639 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(2, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(0., dtype=float32)})
Trajectory(
{'step_type': array(2, dtype=int32),
 'observation': array([-0.16670552, -0.43198496,  0.21276361,  0.8830067 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(0, dtype=int32),
 'reward': array(0., dtype=float32),
 'discount': array(1., dtype=float32)})
Average Return:  18.0

TensorFlow Drivers

We also have drivers in TensorFlow, which are functionally similar to Python drivers but use TF environments, TF policies, TF observers, etc. We currently have two TensorFlow drivers: DynamicStepDriver, which terminates after a given number of (valid) environment steps, and DynamicEpisodeDriver, which terminates after a given number of episodes. Let's look at an example of the DynamicEpisodeDriver in action.

env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)

tf_policy = random_tf_policy.RandomTFPolicy(action_spec=tf_env.action_spec(),
                                            time_step_spec=tf_env.time_step_spec())


num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()
observers = [num_episodes, env_steps]
driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env, tf_policy, observers, num_episodes=2)

# Initial driver.run will reset the environment and initialize the policy.
final_time_step, policy_state = driver.run()

print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.0367443 ,  0.00652178,  0.04001181, -0.00376746]],
      dtype=float32)>})
Number of Steps:  34
Number of Episodes:  2
# Continue running from previous state
final_time_step, _ = driver.run(final_time_step, policy_state)

print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.04702466, -0.04836502,  0.01751254, -0.00393545]],
      dtype=float32)>})
Number of Steps:  63
Number of Episodes:  4
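
For comparison, below is a minimal sketch (not executed in this notebook) of the other TensorFlow driver mentioned above, DynamicStepDriver, which terminates after a given number of (valid) environment steps rather than episodes. It assumes the same tf_env and tf_policy defined earlier.

from tf_agents.drivers import dynamic_step_driver

step_metric = tf_metrics.EnvironmentSteps()
step_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env, tf_policy, observers=[step_metric], num_steps=10)

# The first call to run() resets the environment and initializes the policy.
final_time_step, policy_state = step_driver.run()
print('Number of Steps: ', step_metric.result().numpy())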