分布式训练

分布式训练是一种模型训练方式,它将计算资源需求(如 CPU、内存)分配到多台计算机上。分布式训练能够实现更快的训练速度,并支持更大规模的数据集(多达数十亿个样本)。

分布式训练对于自动化超参数优化也非常有用,在这种场景下可以并行训练多个模型。

在本文中,您将学习如何

  • 使用分布式训练来训练 TF-DF 模型。
  • 使用分布式训练来调整 TF-DF 模型的超参数。

限制

目前,分布式训练支持:

如何启用分布式训练

本节列出了启用分布式训练的步骤。有关完整示例,请参见下一节。

ParameterServerStrategy 作用域

模型和数据集是在 ParameterServerStrategy 作用域内定义的。

strategy = tf.distribute.experimental.ParameterServerStrategy(...)
with strategy.scope():
  model = tfdf.keras.DistributedGradientBoostedTreesModel()
  distributed_train_dataset = strategy.distribute_datasets_from_function(dataset_fn)
model.fit(distributed_train_dataset)

数据集格式

与非分布式训练一样,数据集可以通过以下方式提供:

  1. 有限的 TensorFlow 分布式数据集,或
  2. 使用兼容数据集格式之一的数据集文件路径。

使用分片文件(sharded files)比使用有限的 TensorFlow 分布式数据集方法简单得多(1 行代码对比约 20 行代码)。但是,只有 TensorFlow 数据集方法支持 TensorFlow 预处理。如果您的流水线不包含任何预处理,建议使用分片数据集选项。

在这两种情况下,数据集都应分片成多个文件,以便有效地分配数据集读取任务。

设置工作节点 (Workers)

首席进程 (Chief process) 是运行定义 TensorFlow 模型的 Python 代码的程序。此进程不进行繁重的计算。实际的训练计算由工作节点 (Workers) 完成。工作节点是运行 TensorFlow 参数服务器的进程。

首席进程应配置工作节点的 IP 地址。这可以通过 TF_CONFIG 环境变量,或通过创建 ClusterResolver 来完成。更多详细信息,请参阅使用 ParameterServerStrategy 进行参数服务器训练

TensorFlow 的 ParameterServerStrategy 定义了两种类型的工作节点:“工作节点 (workers)”和“参数服务器 (parameter server)”。TensorFlow 要求每种类型的节点至少实例化一个。但是,TF-DF 仅使用“工作节点”。因此,需要实例化一个“参数服务器”,但它不会被 TF-DF 使用。例如,TF-DF 训练的配置可能如下所示:

  • 1 个首席进程 (Chief)
  • 50 个工作节点 (Workers)
  • 1 个参数服务器 (Parameter server)

工作节点需要访问 TensorFlow Decision Forests 的自定义训练算子。启用访问有两个选项:

  1. 使用预配置的 TF-DF C++ 参数服务器 //third_party/tensorflow_decision_forests/tensorflow/distribute:tensorflow_std_server
  2. 通过调用 tf.distribute.Server() 创建参数服务器。在这种情况下,应导入 TF-DF:import tensorflow_decision_forests

示例

本节展示了分布式训练配置的完整示例。如需更多示例,请查看 TF-DF 单元测试

示例:基于数据集路径的分布式训练

使用兼容数据集格式之一,将您的数据集划分为一组分片文件。建议按以下方式命名文件:/path/to/dataset/train-<5 位索引>-of-<文件总数>,例如:

/path/to/dataset/train-00000-of-00100
/path/to/dataset/train-00001-of-00005
/path/to/dataset/train-00002-of-00005
...

为了获得最高效率,文件数量应至少为工作节点数量的 10 倍。例如,如果您使用 100 个工作节点进行训练,请确保数据集至少被划分为 1000 个文件。

然后,可以使用如下的分片表达式引用这些文件:

  • /path/to/dataset/train@1000
  • /path/to/dataset/train@*

分布式训练执行方式如下。在此示例中,数据集存储为 TensorFlow Examples 的 TFRecord(由键 tfrecord+tfe 定义)。

import tensorflow_decision_forests as tfdf
import tensorflow as tf

strategy = tf.distribute.experimental.ParameterServerStrategy(...)

with strategy.scope():
  model = tfdf.keras.DistributedGradientBoostedTreesModel()

model.fit_on_dataset_path(
    train_path="/path/to/dataset/train@1000",
    label_key="label_key",
    dataset_format="tfrecord+tfe")

print("Trained model")
model.summary()

示例:基于有限 TensorFlow 分布式数据集的分布式训练

TF-DF 需要一个分布式的、有限的、按工作节点分片(worker-sharded)的 TensorFlow 数据集。

  • 分布式:非分布式数据集需包裹在 strategy.distribute_datasets_from_function 中。
  • 有限:数据集应恰好读取每个样本一次。数据集不应包含任何 repeat 指令。
  • 按工作节点分片:每个工作节点应读取数据集的不同部分。

以下是一个示例:

import tensorflow_decision_forests as tfdf
import tensorflow as tf


def dataset_fn(context, paths):
  """Create a worker-sharded finite dataset from paths.

  Like for non-distributed training, each example should be visited exactly
  once (and by only one worker) during the training. In addition, for optimal
  training speed, the reading of the examples should be distributed among the
  workers (instead of being read by a single worker, or read and discarded
  multiple times).

  In other words, don't add a "repeat" statement and make sure to shard the
  dataset at the file level and not at the example level.
  """

  # List the dataset files
  ds_path = tf.data.Dataset.from_tensor_slices(paths)

  # Make sure the dataset is used with distributed training.
  assert context is not None


  # Split the among the workers.
  #
  # Note: The "shard" is applied on the file path. The shard should not be
  # applied on the examples directly.
  # Note: You cannot use 'context.num_input_pipelines' with ParameterServerV2.
  current_worker = tfdf.keras.get_worker_idx_and_num_workers(context)
  ds_path = ds_path.shard(
      num_shards=current_worker.num_workers,
      index=current_worker.worker_idx)

  def read_csv_file(path):
    """Reads a single csv file."""

    numerical = tf.constant([0.0], dtype=tf.float32)
    categorical_string = tf.constant(["NA"], dtype=tf.string)
    csv_columns = [
        numerical,  # feature 1
        categorical_string,  # feature 2
        numerical,  # feature 3
        # ... define the features here.
    ]
    return tf.data.experimental.CsvDataset(path, csv_columns, header=True)

  ds_columns = ds_path.interleave(read_csv_file)

  # We assume a binary classification label with the following possible values.
  label_values = ["<=50K", ">50K"]

  # Convert the text labels into integers:
  # "<=50K" => 0
  # ">50K" => 1
  init_label_table = tf.lookup.KeyValueTensorInitializer(
      keys=tf.constant(label_values),
      values=tf.constant(range(label_values), dtype=tf.int64))
  label_table = tf.lookup.StaticVocabularyTable(
      init_label_table, num_oov_buckets=1)

  def extract_label(*columns):
    return columns[0:-1], label_table.lookup(columns[-1])

  ds_dataset = ds_columns.map(extract_label)

  # The batch size has no impact on the quality of the model. However, a larger
  # batch size generally is faster.
  ds_dataset = ds_dataset.batch(500)
  return ds_dataset


strategy = tf.distribute.experimental.ParameterServerStrategy(...)
with strategy.scope():
  model = tfdf.keras.DistributedGradientBoostedTreesModel()

  train_dataset = strategy.distribute_datasets_from_function(
      lambda context: dataset_fn(context, [...list of csv files...])
  )

model.fit(train_dataset)

print("Trained model")
model.summary()

示例:基于数据集路径的分布式超参数调优

基于数据集路径的分布式超参数调优与分布式训练类似。唯一的区别是该选项与非分布式模型兼容。例如,您可以分发(非分布式)梯度提升树模型的超参数调优任务。

with strategy.scope():
  tuner = tfdf.tuner.RandomSearch(num_trials=30, use_predefined_hps=True)
  model = tfdf.keras.GradientBoostedTreesModel(tuner=tuner)

training_history = model.fit_on_dataset_path(
  train_path=train_path,
  label_key=label,
  dataset_format="csv",
  valid_path=test_path)

logging.info("Trained model:")
model.summary()

示例:单元测试

为了对分布式训练进行单元测试,您可以创建模拟工作进程。有关更多信息,请参阅 TF-DF 单元测试中的 _create_in_process_tf_ps_cluster 方法。