在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
TensorFlow 提供了一组伪随机数生成器 (RNG),位于 tf.random
模块中。本文档介绍了如何控制随机数生成器,以及这些生成器如何与其他 TensorFlow 子系统交互。
TensorFlow 提供了两种方法来控制随机数生成过程
通过显式使用
tf.random.Generator
对象。每个这样的对象都维护一个状态(在tf.Variable
中),该状态将在每次数字生成后更改。通过纯函数式无状态随机函数,例如
tf.random.stateless_uniform
。使用相同的参数(包括种子)并在同一设备上调用这些函数将始终产生相同的结果。
设置
import tensorflow as tf
# Creates some virtual devices (cpu:0, cpu:1, etc.) for using distribution strategy
physical_devices = tf.config.list_physical_devices("CPU")
tf.config.experimental.set_virtual_device_configuration(
physical_devices[0], [
tf.config.experimental.VirtualDeviceConfiguration(),
tf.config.experimental.VirtualDeviceConfiguration(),
tf.config.experimental.VirtualDeviceConfiguration()
])
2024-01-17 02:22:51.386100: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-01-17 02:22:51.386148: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-01-17 02:22:51.387696: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
tf.random.Generator
类
当您希望每次 RNG 调用都产生不同的结果时,使用 tf.random.Generator
类。它维护一个内部状态(由 tf.Variable
对象管理),该状态将在每次生成随机数时更新。由于状态由 tf.Variable
管理,因此它可以享受 tf.Variable
提供的所有功能,例如轻松检查点、自动控制依赖关系和线程安全。
您可以通过手动创建类的对象来获取 tf.random.Generator
,或者调用 tf.random.get_global_generator()
来获取默认的全局生成器。
g1 = tf.random.Generator.from_seed(1)
print(g1.normal(shape=[2, 3]))
g2 = tf.random.get_global_generator()
print(g2.normal(shape=[2, 3]))
tf.Tensor( [[ 0.43842277 -0.53439844 -0.07710262] [ 1.5658045 -0.1012345 -0.2744976 ]], shape=(2, 3), dtype=float32) tf.Tensor( [[ 0.24077678 0.39891425 0.03557164] [-0.15206331 -0.7270625 1.8158559 ]], shape=(2, 3), dtype=float32)
有多种方法可以创建生成器对象。最简单的方法是 Generator.from_seed
,如上所示,它从种子创建生成器。种子可以是任何非负整数。 from_seed
还接受一个可选参数 alg
,它是此生成器将使用的 RNG 算法。
g1 = tf.random.Generator.from_seed(1, alg='philox')
print(g1.normal(shape=[2, 3]))
tf.Tensor( [[ 0.43842277 -0.53439844 -0.07710262] [ 1.5658045 -0.1012345 -0.2744976 ]], shape=(2, 3), dtype=float32)
有关此算法的更多信息,请参见下面的“算法”部分。
另一种创建生成器的方法是使用 Generator.from_non_deterministic_state
。以这种方式创建的生成器将从非确定性状态开始,具体取决于例如时间和操作系统。
g = tf.random.Generator.from_non_deterministic_state()
print(g.normal(shape=[2, 3]))
tf.Tensor( [[-0.8503367 -0.8919918 0.688985 ] [-0.51400167 0.57703274 -0.5177701 ]], shape=(2, 3), dtype=float32)
还有其他方法可以创建生成器,例如从显式状态创建,本指南未涵盖这些方法。
当使用 tf.random.get_global_generator
获取全局生成器时,您需要注意设备放置。全局生成器在第一次调用 tf.random.get_global_generator
时(从非确定性状态)创建,并放置在该调用时的默认设备上。因此,例如,如果您第一次调用 tf.random.get_global_generator
的位置在 tf.device("gpu")
范围内,则全局生成器将放置在 GPU 上,稍后从 CPU 使用全局生成器将导致 GPU 到 CPU 的复制。
还有一个函数 tf.random.set_global_generator
用于将全局生成器替换为另一个生成器对象。但是,应谨慎使用此函数,因为旧的全局生成器可能已被 tf.function
(作为弱引用)捕获,替换它将导致它被垃圾回收,从而破坏 tf.function
。重置全局生成器的更好方法是使用其中一个“重置”函数,例如 Generator.reset_from_seed
,它不会创建新的生成器对象。
g = tf.random.Generator.from_seed(1)
print(g.normal([]))
print(g.normal([]))
g.reset_from_seed(1)
print(g.normal([]))
tf.Tensor(0.43842277, shape=(), dtype=float32) tf.Tensor(1.6272374, shape=(), dtype=float32) tf.Tensor(0.43842277, shape=(), dtype=float32)
创建独立的随机数流
在许多应用程序中,需要多个独立的随机数流,独立是指它们不会重叠并且不会有任何统计上可检测的相关性。这是通过使用 Generator.split
来创建多个生成器来实现的,这些生成器保证彼此独立(即生成独立的流)。
g = tf.random.Generator.from_seed(1)
print(g.normal([]))
new_gs = g.split(3)
for new_g in new_gs:
print(new_g.normal([]))
print(g.normal([]))
tf.Tensor(0.43842277, shape=(), dtype=float32) tf.Tensor(2.536413, shape=(), dtype=float32) tf.Tensor(0.33186463, shape=(), dtype=float32) tf.Tensor(-0.07144657, shape=(), dtype=float32) tf.Tensor(-0.79253083, shape=(), dtype=float32)
split
将更改其调用的生成器(上面示例中的 g
)的状态,类似于 RNG 方法,例如 normal
。除了彼此独立之外,新的生成器 (new_gs
) 也保证独立于旧的生成器 (g
)。
当您想确保使用的生成器与其他计算位于同一设备上时,生成新的生成器也很有用,以避免跨设备复制的开销。例如
with tf.device("cpu"): # change "cpu" to the device you want
g = tf.random.get_global_generator().split(1)[0]
print(g.normal([])) # use of g won't cause cross-device copy, unlike the global generator
tf.Tensor(0.4637335, shape=(), dtype=float32)
您可以递归地进行拆分,在拆分的生成器上调用 split
。递归的深度没有限制(除了整数溢出)。
与 tf.function
的交互
tf.random.Generator
遵循与 tf.Variable
相同的规则,当与 tf.function
一起使用时。这包括三个方面。
在 tf.function
之外创建生成器
tf.function
可以使用在它之外创建的生成器。
g = tf.random.Generator.from_seed(1)
@tf.function
def foo():
return g.normal([])
print(foo())
tf.Tensor(0.43842277, shape=(), dtype=float32)
用户需要确保生成器对象在调用函数时仍然存在(未被垃圾回收)。
在 tf.function
内部创建生成器
在 tf.function
内部创建生成器只能在函数的第一次运行期间发生。
g = None
@tf.function
def foo():
global g
if g is None:
g = tf.random.Generator.from_seed(1)
return g.normal([])
print(foo())
print(foo())
tf.Tensor(0.43842277, shape=(), dtype=float32) tf.Tensor(1.6272374, shape=(), dtype=float32)
将生成器作为参数传递给 tf.function
当用作 tf.function
的参数时,不同的生成器对象将导致 tf.function
的重新追踪。
num_traces = 0
@tf.function
def foo(g):
global num_traces
num_traces += 1
return g.normal([])
foo(tf.random.Generator.from_seed(1))
foo(tf.random.Generator.from_seed(2))
print(num_traces)
2
请注意,这种重新追踪行为与 tf.Variable
一致。
num_traces = 0
@tf.function
def foo(v):
global num_traces
num_traces += 1
return v.read_value()
foo(tf.Variable(1))
foo(tf.Variable(2))
print(num_traces)
1
与分布式策略的交互
Generator
与分布式策略的交互方式有两种。
在分布式策略之外创建生成器
如果在策略范围之外创建生成器,则所有副本对生成器的访问将被序列化,因此副本将获得不同的随机数。
g = tf.random.Generator.from_seed(1)
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
def f():
print(g.normal([]))
results = strat.run(f)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. tf.Tensor(0.43842274, shape=(), dtype=float32) tf.Tensor(1.6272374, shape=(), dtype=float32)
请注意,这种用法可能会导致性能问题,因为生成器的设备与副本不同。
在分布式策略内部创建生成器
如果在策略范围内创建生成器,则每个副本将获得不同的独立随机数流。
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
g = tf.random.Generator.from_seed(1)
print(strat.run(lambda: g.normal([])))
print(strat.run(lambda: g.normal([])))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. PerReplica:{ 0: tf.Tensor(-0.87930447, shape=(), dtype=float32), 1: tf.Tensor(0.020661574, shape=(), dtype=float32) } WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. PerReplica:{ 0: tf.Tensor(-1.5822568, shape=(), dtype=float32), 1: tf.Tensor(0.77539235, shape=(), dtype=float32) }
如果生成器已播种(例如,由 Generator.from_seed
创建),则随机数由种子决定,即使不同的副本获得不同的不相关数字。可以将副本上生成的随机数视为副本 ID 和所有副本共有的“主要”随机数的哈希值。因此,整个系统仍然是确定性的。
tf.random.Generator
也可以在 Strategy.run
内部创建。
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
def f():
g = tf.random.Generator.from_seed(1)
a = g.normal([])
b = g.normal([])
return tf.stack([a, b])
print(strat.run(f))
print(strat.run(f))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. PerReplica:{ 0: tf.Tensor([-0.87930447 -1.5822568 ], shape=(2,), dtype=float32), 1: tf.Tensor([0.02066157 0.77539235], shape=(2,), dtype=float32) } WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. PerReplica:{ 0: tf.Tensor([-0.87930447 -1.5822568 ], shape=(2,), dtype=float32), 1: tf.Tensor([0.02066157 0.77539235], shape=(2,), dtype=float32) }
我们不再建议将 tf.random.Generator
作为参数传递给 Strategy.run
,因为 Strategy.run
通常希望参数是张量,而不是生成器。
保存生成器
通常,对于保存或序列化,您可以像处理 tf.Variable
或 tf.Module
(或其子类)一样处理 tf.random.Generator
。在 TF 中,有两种序列化机制:检查点 和 SavedModel。
检查点
可以使用 tf.train.Checkpoint
自由地保存和恢复生成器。从恢复点开始的随机数流将与从保存点开始的随机数流相同。
filename = "./checkpoint"
g = tf.random.Generator.from_seed(1)
cp = tf.train.Checkpoint(generator=g)
print(g.normal([]))
tf.Tensor(0.43842277, shape=(), dtype=float32)
cp.write(filename)
print("RNG stream from saving point:")
print(g.normal([]))
print(g.normal([]))
RNG stream from saving point: tf.Tensor(1.6272374, shape=(), dtype=float32) tf.Tensor(1.6307176, shape=(), dtype=float32)
cp.restore(filename)
print("RNG stream from restoring point:")
print(g.normal([]))
print(g.normal([]))
RNG stream from restoring point: tf.Tensor(1.6272374, shape=(), dtype=float32) tf.Tensor(1.6307176, shape=(), dtype=float32)
您也可以在分布式策略中保存和恢复。
filename = "./checkpoint"
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
g = tf.random.Generator.from_seed(1)
cp = tf.train.Checkpoint(my_generator=g)
print(strat.run(lambda: g.normal([])))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') PerReplica:{ 0: tf.Tensor(-0.87930447, shape=(), dtype=float32), 1: tf.Tensor(0.020661574, shape=(), dtype=float32) }
with strat.scope():
cp.write(filename)
print("RNG stream from saving point:")
print(strat.run(lambda: g.normal([])))
print(strat.run(lambda: g.normal([])))
RNG stream from saving point: PerReplica:{ 0: tf.Tensor(-1.5822568, shape=(), dtype=float32), 1: tf.Tensor(0.77539235, shape=(), dtype=float32) } PerReplica:{ 0: tf.Tensor(-0.5039703, shape=(), dtype=float32), 1: tf.Tensor(0.1251838, shape=(), dtype=float32) }
with strat.scope():
cp.restore(filename)
print("RNG stream from restoring point:")
print(strat.run(lambda: g.normal([])))
print(strat.run(lambda: g.normal([])))
RNG stream from restoring point: PerReplica:{ 0: tf.Tensor(-1.5822568, shape=(), dtype=float32), 1: tf.Tensor(0.77539235, shape=(), dtype=float32) } PerReplica:{ 0: tf.Tensor(-0.5039703, shape=(), dtype=float32), 1: tf.Tensor(0.1251838, shape=(), dtype=float32) }
您应该确保副本在保存之前不会在它们的 RNG 调用历史记录中出现分歧(例如,一个副本进行一次 RNG 调用,而另一个副本进行两次 RNG 调用)。否则,它们的内部 RNG 状态将出现分歧,并且 tf.train.Checkpoint
(它只保存第一个副本的状态)将无法正确恢复所有副本。
您也可以将保存的检查点恢复到具有不同副本数量的不同分布式策略。因为在策略范围内创建的 tf.random.Generator
对象只能在同一策略中使用,所以要恢复到不同的策略,您必须在目标策略中创建一个新的 tf.random.Generator
以及一个新的 tf.train.Checkpoint
,如本示例所示。
filename = "./checkpoint"
strat1 = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat1.scope():
g1 = tf.random.Generator.from_seed(1)
cp1 = tf.train.Checkpoint(my_generator=g1)
print(strat1.run(lambda: g1.normal([])))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') PerReplica:{ 0: tf.Tensor(-0.87930447, shape=(), dtype=float32), 1: tf.Tensor(0.020661574, shape=(), dtype=float32) }
with strat1.scope():
cp1.write(filename)
print("RNG stream from saving point:")
print(strat1.run(lambda: g1.normal([])))
print(strat1.run(lambda: g1.normal([])))
RNG stream from saving point: PerReplica:{ 0: tf.Tensor(-1.5822568, shape=(), dtype=float32), 1: tf.Tensor(0.77539235, shape=(), dtype=float32) } PerReplica:{ 0: tf.Tensor(-0.5039703, shape=(), dtype=float32), 1: tf.Tensor(0.1251838, shape=(), dtype=float32) }
strat2 = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1", "cpu:2"])
with strat2.scope():
g2 = tf.random.Generator.from_seed(1)
cp2 = tf.train.Checkpoint(my_generator=g2)
cp2.restore(filename)
print("RNG stream from restoring point:")
print(strat2.run(lambda: g2.normal([])))
print(strat2.run(lambda: g2.normal([])))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1', '/job:localhost/replica:0/task:0/device:CPU:2') RNG stream from restoring point: PerReplica:{ 0: tf.Tensor(-1.5822568, shape=(), dtype=float32), 1: tf.Tensor(0.77539235, shape=(), dtype=float32), 2: tf.Tensor(0.6851049, shape=(), dtype=float32) } PerReplica:{ 0: tf.Tensor(-0.5039703, shape=(), dtype=float32), 1: tf.Tensor(0.1251838, shape=(), dtype=float32), 2: tf.Tensor(-0.58519536, shape=(), dtype=float32) }
虽然 g1
和 cp1
与 g2
和 cp2
是不同的对象,但它们通过公共检查点文件 filename
和对象名称 my_generator
相互链接。策略之间重叠的副本(例如上面的 cpu:0
和 cpu:1
)将像之前的示例一样正确恢复它们的 RNG 流。此保证不涵盖将生成器保存在策略范围内并在任何策略范围之外恢复或反之亦然的情况,因为策略之外的设备被视为与策略中的任何副本不同。
SavedModel
tf.random.Generator
可以保存到 SavedModel。生成器可以在策略范围内创建。保存也可以在策略范围内进行。
filename = "./saved_model"
class MyModule(tf.Module):
def __init__(self):
super(MyModule, self).__init__()
self.g = tf.random.Generator.from_seed(0)
@tf.function
def __call__(self):
return self.g.normal([])
@tf.function
def state(self):
return self.g.state
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
m = MyModule()
print(strat.run(m))
print("state:", m.state())
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1') PerReplica:{ 0: tf.Tensor(-1.4154755, shape=(), dtype=float32), 1: tf.Tensor(-0.11388441, shape=(), dtype=float32) } state: tf.Tensor([256 0 0], shape=(3,), dtype=int64)
with strat.scope():
tf.saved_model.save(m, filename)
print("RNG stream from saving point:")
print(strat.run(m))
print("state:", m.state())
print(strat.run(m))
print("state:", m.state())
INFO:tensorflow:Assets written to: ./saved_model/assets RNG stream from saving point: PerReplica:{ 0: tf.Tensor(-0.68758255, shape=(), dtype=float32), 1: tf.Tensor(0.8084062, shape=(), dtype=float32) } state: tf.Tensor([512 0 0], shape=(3,), dtype=int64) PerReplica:{ 0: tf.Tensor(-0.27342677, shape=(), dtype=float32), 1: tf.Tensor(-0.53093255, shape=(), dtype=float32) } state: tf.Tensor([768 0 0], shape=(3,), dtype=int64)
imported = tf.saved_model.load(filename)
print("RNG stream from loading point:")
print("state:", imported.state())
print(imported())
print("state:", imported.state())
print(imported())
print("state:", imported.state())
RNG stream from loading point: state: tf.Tensor([256 0 0], shape=(3,), dtype=int64) tf.Tensor(-1.0359411, shape=(), dtype=float32) state: tf.Tensor([512 0 0], shape=(3,), dtype=int64) tf.Tensor(-0.06425078, shape=(), dtype=float32) state: tf.Tensor([768 0 0], shape=(3,), dtype=int64)
不建议将包含 tf.random.Generator
的 SavedModel 加载到分布式策略中,因为所有副本都将生成相同的随机数流(这是因为副本 ID 在 SavedModel 的图中被冻结)。
将分布式 tf.random.Generator
(在分布式策略中创建的生成器)加载到非策略环境中,例如上面的示例,也存在一个注意事项。RNG 状态将被正确恢复,但生成的随机数将与策略中原始生成器的随机数不同(同样是因为策略之外的设备被视为与策略中的任何副本不同)。
无状态 RNG
无状态 RNG 的使用很简单。由于它们只是纯函数,因此没有状态或副作用。
print(tf.random.stateless_normal(shape=[2, 3], seed=[1, 2]))
print(tf.random.stateless_normal(shape=[2, 3], seed=[1, 2]))
tf.Tensor( [[ 0.5441101 0.20738031 0.07356433] [ 0.04643455 -1.30159 -0.95385665]], shape=(2, 3), dtype=float32) tf.Tensor( [[ 0.5441101 0.20738031 0.07356433] [ 0.04643455 -1.30159 -0.95385665]], shape=(2, 3), dtype=float32)
每个无状态 RNG 都需要一个 seed
参数,该参数需要是一个形状为 [2]
的整数张量。操作的结果完全由这个种子决定。
无状态 RNG 使用的 RNG 算法是设备相关的,这意味着在不同设备上运行相同的操作可能会产生不同的输出。
算法
通用
tf.random.Generator
类和 stateless
函数都支持 Philox 算法(写成 "philox"
或 tf.random.Algorithm.PHILOX
),适用于所有设备。
如果使用相同的算法并从相同的状态开始,不同的设备将生成相同的整数。它们也将生成“几乎相同的”浮点数,尽管由于设备执行浮点计算的不同方式(例如,约简顺序)可能会存在微小的数值差异。
XLA 设备
在 XLA 驱动的设备(例如 TPU,以及启用 XLA 时 CPU/GPU)上,也支持 ThreeFry 算法(写成 "threefry"
或 tf.random.Algorithm.THREEFRY
)。该算法在 TPU 上速度很快,但在 CPU/GPU 上比 Philox 慢。
有关这些算法的更多详细信息,请参阅论文 'Parallel Random Numbers: As Easy as 1, 2, 3'。