使用 TF Profiler 分析 tf.data 性能

概述

本指南假设您熟悉 TensorFlow Profilertf.data。它旨在提供分步说明和示例,以帮助用户诊断和解决输入管道性能问题。

首先,收集 TensorFlow 作业的配置文件。有关如何执行此操作的说明适用于 CPU/GPUCloud TPU

TensorFlow Trace Viewer

下面详细介绍的分析工作流程侧重于 Profiler 中的跟踪查看器工具。此工具显示一个时间线,该时间线显示 TensorFlow 程序执行的操作的持续时间,并允许您识别哪些操作执行时间最长。有关跟踪查看器的更多信息,请查看 TF Profiler 指南的这一部分。通常,tf.data 事件将出现在主机 CPU 时间线上。

分析工作流程

请按照以下工作流程操作。如果您有任何反馈可以帮助我们改进它,请 创建一个带有“comp:data”标签的 github 问题

1. 您的 tf.data 管道是否以足够快的速度生成数据?

首先确定输入管道是否是 TensorFlow 程序的瓶颈。

为此,请在跟踪查看器中查找 IteratorGetNext::DoCompute 操作。通常,您希望在步骤开始时看到这些操作。这些片段代表输入管道在请求时生成一批元素所需的时间。如果您使用的是 keras 或在 tf.function 中迭代数据集,则应在 tf_data_iterator_get_next 线程中找到这些片段。

请注意,如果您使用的是 分布式策略,您可能会看到 IteratorGetNextAsOptional::DoCompute 事件,而不是 IteratorGetNext::DoCompute(截至 TF 2.3)。

image

如果调用快速返回(<= 50 us),这意味着您的数据在请求时可用。输入管道不是您的瓶颈;有关更通用的性能分析技巧,请参阅性能分析器指南

image

如果调用返回缓慢,tf.data 无法跟上消费者的请求。继续下一节。

2. 您是否正在预取数据?

输入管道性能的最佳实践是在您的 tf.data 管道末尾插入一个 tf.data.Dataset.prefetch 变换。此变换将输入管道的预处理计算与模型计算的下一步重叠,并且在训练模型时,它是获得最佳输入管道性能所必需的。如果您正在预取数据,您应该在与 IteratorGetNext::DoCompute 操作相同的线程上看到一个 Iterator::Prefetch 片段。

image

如果您在管道末尾没有 prefetch您应该添加一个。有关 tf.data 性能建议的更多信息,请参阅tf.data 性能指南

如果您已经预取了数据,并且输入管道仍然是您的瓶颈,请继续下一节以进一步分析性能。

3. 您是否达到了高 CPU 利用率?

tf.data 通过尝试尽可能充分利用可用资源来实现高吞吐量。通常,即使在 GPU 或 TPU 等加速器上运行模型时,tf.data 管道也会在 CPU 上运行。您可以使用 sarhtop 等工具检查您的利用率,或者如果您在 GCP 上运行,则在 云监控控制台 中检查。

如果您的利用率很低,这表明您的输入管道可能没有充分利用主机 CPU。您应该查阅tf.data 性能指南 以了解最佳实践。如果您已经应用了最佳实践,并且利用率和吞吐量仍然很低,请继续到下面的瓶颈分析

如果您的利用率接近资源限制,为了进一步提高性能,您需要提高输入管道的效率(例如,避免不必要的计算)或卸载计算。

您可以通过避免在 tf.data 中进行不必要的计算来提高输入管道的效率。一种方法是在计算密集型工作之后插入一个 tf.data.Dataset.cache 变换,前提是您的数据适合内存;这以增加内存使用为代价减少了计算。此外,在 tf.data 中禁用操作内并行性有可能将效率提高 > 10%,并且可以通过在您的输入管道上设置以下选项来完成

dataset = ...
options = tf.data.Options()
options.experimental_threading.max_intra_op_parallelism = 1
dataset = dataset.with_options(options)

4. 瓶颈分析

以下部分将逐步介绍如何在跟踪查看器中读取 tf.data 事件,以了解瓶颈在哪里以及可能的缓解策略。

了解性能分析器中的 tf.data 事件

性能分析器中的每个 tf.data 事件都具有名称 Iterator::<Dataset>,其中 <Dataset> 是数据集源或变换的名称。每个事件还具有长名称 Iterator::<Dataset_1>::...::<Dataset_n>,您可以通过单击 tf.data 事件来查看。在长名称中,<Dataset_n> 与来自(短)名称的 <Dataset> 相匹配,长名称中的其他数据集表示下游变换。

image

例如,上面的屏幕截图是从以下代码生成的

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x)
dataset = dataset.repeat(2)
dataset = dataset.batch(5)

这里,Iterator::Map 事件具有长名称 Iterator::BatchV2::FiniteRepeat::Map。请注意,数据集名称可能与 Python API 略有不同(例如,FiniteRepeat 而不是 Repeat),但应该足够直观以供解析。

同步和异步变换

对于同步 tf.data 变换(例如 BatchMap),您将在同一线程上看到来自上游变换的事件。在上面的示例中,由于所有使用的变换都是同步的,因此所有事件都出现在同一线程上。

对于异步变换(例如 PrefetchParallelMapParallelInterleaveMapAndBatch),来自上游变换的事件将在不同的线程上。在这种情况下,“长名称”可以帮助您识别管道中事件对应的变换。

image

例如,上面的屏幕截图是从以下代码生成的

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x)
dataset = dataset.repeat(2)
dataset = dataset.batch(5)
dataset = dataset.prefetch(1)

这里,Iterator::Prefetch 事件位于 tf_data_iterator_get_next 线程上。由于 Prefetch 是异步的,因此它的输入事件 (BatchV2) 将位于不同的线程上,并且可以通过搜索长名称 Iterator::Prefetch::BatchV2 来定位。在这种情况下,它们位于 tf_data_iterator_resource 线程上。从它的长名称,您可以推断出 BatchV2Prefetch 的上游。此外,BatchV2 事件的 parent_id 将与 Prefetch 事件的 ID 相匹配。

识别瓶颈

通常,要识别输入管道中的瓶颈,请从最外层的变换一直走到源头。从管道中的最终变换开始,递归到上游变换,直到找到一个缓慢的变换或到达源数据集,例如 TFRecord。在上面的示例中,您将从 Prefetch 开始,然后向上游走到 BatchV2FiniteRepeatMap,最后是 Range

通常,一个缓慢的变换对应于一个事件很长,但输入事件很短的变换。下面是一些示例。

请注意,大多数主机输入管道中的最终(最外层)变换是 Iterator::Model 事件。Model 变换由 tf.data 运行时自动引入,用于检测和自动调整输入管道性能。

如果您的作业使用 分布式策略,则跟踪查看器将包含与设备输入管道相对应的其他事件。设备管道的最外层变换(嵌套在 IteratorGetNextOp::DoComputeIteratorGetNextAsOptionalOp::DoCompute 下)将是一个 Iterator::Prefetch 事件,其上游是一个 Iterator::Generator 事件。您可以通过搜索 Iterator::Model 事件来找到相应的宿主管道。

示例 1

image

上面的屏幕截图是从以下输入管道生成的

dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(parse_record)
dataset = dataset.batch(32)
dataset = dataset.repeat()

在屏幕截图中,观察到 (1) Iterator::Map 事件很长,但 (2) 它的输入事件 (Iterator::FlatMap) 快速返回。这表明顺序 Map 变换是瓶颈。

请注意,在屏幕截图中,InstantiatedCapturedFunction::Run 事件对应于执行映射函数所需的时间。

示例 2

image

上面的屏幕截图是从以下输入管道生成的

dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(parse_record, num_parallel_calls=2)
dataset = dataset.batch(32)
dataset = dataset.repeat()

此示例类似于上面的示例,但使用 ParallelMap 而不是 Map。我们在这里注意到 (1) Iterator::ParallelMap 事件很长,但 (2) 它的输入事件 Iterator::FlatMap(由于 ParallelMap 是异步的,因此位于不同的线程上)很短。这表明 ParallelMap 变换是瓶颈。

解决瓶颈

源数据集

如果您已经确定数据集源是瓶颈,例如从 TFRecord 文件中读取,您可以通过并行化数据提取来提高性能。为此,请确保您的数据跨多个文件分片,并使用 tf.data.Dataset.interleave,并将 num_parallel_calls 参数设置为 tf.data.AUTOTUNE。如果确定性对您的程序不重要,您可以通过在 tf.data.Dataset.interleave 上设置 deterministic=False 标志来进一步提高性能(从 TF 2.2 开始)。例如,如果您从 TFRecords 中读取,您可以执行以下操作

dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.interleave(tf.data.TFRecordDataset,
  num_parallel_calls=tf.data.AUTOTUNE,
  deterministic=False)

请注意,分片文件应该足够大以摊销打开文件的开销。有关并行数据提取的更多详细信息,请参阅tf.data 性能指南的这一部分

变换数据集

如果您已经确定中间 tf.data 变换是瓶颈,您可以通过并行化变换或缓存计算 来解决它,前提是您的数据适合内存并且它适用。一些变换(例如 Map)具有并行对应项;tf.data 性能指南演示了 如何并行化这些变换。其他变换,例如 FilterUnbatchBatch 本质上是顺序的;您可以通过引入“外部并行性”来并行化它们。例如,假设您的输入管道最初看起来像下面这样,其中 Batch 是瓶颈

filenames = tf.data.Dataset.list_files(file_path, shuffle=is_training)
dataset = filenames_to_dataset(filenames)
dataset = dataset.batch(batch_size)

您可以通过在分片输入上运行输入管道的多个副本并组合结果来引入“外部并行性”

filenames = tf.data.Dataset.list_files(file_path, shuffle=is_training)

def make_dataset(shard_index):
  filenames = filenames.shard(NUM_SHARDS, shard_index)
  dataset = filenames_to_dataset(filenames)
  Return dataset.batch(batch_size)

indices = tf.data.Dataset.range(NUM_SHARDS)
dataset = indices.interleave(make_dataset,
                             num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

其他资源