调谐器组件会调整模型的超参数。
调谐器组件和 KerasTuner 库
调谐器组件广泛使用 Python KerasTuner API 来调整超参数。
组件
调谐器需要
- 用于训练和评估的 tf.Examples。
- 用户提供的模块文件(或模块 fn),用于定义调整逻辑,包括模型定义、超参数搜索空间、目标等。
- Protobuf 定义的训练参数和评估参数。
- (可选)Protobuf 定义的调整参数。
- (可选)由上游 Transform 组件生成的转换图。
- (可选)由 SchemaGen 管道组件创建并由开发人员可选修改的数据架构。
使用给定的数据、模型和目标,调谐器会调整超参数并发出最佳结果。
说明
调谐器需要具有以下签名的用户模块函数 tuner_fn
...
from keras_tuner.engine import base_tuner
TunerFnResult = NamedTuple('TunerFnResult', [('tuner', base_tuner.BaseTuner),
('fit_kwargs', Dict[Text, Any])])
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
"""Build the tuner using the KerasTuner API.
Args:
fn_args: Holds args as name/value pairs.
- working_dir: working dir for tuning.
- train_files: List of file paths containing training tf.Example data.
- eval_files: List of file paths containing eval tf.Example data.
- train_steps: number of train steps.
- eval_steps: number of eval steps.
- schema_path: optional schema of the input data.
- transform_graph_path: optional transform graph produced by TFT.
Returns:
A namedtuple contains the following:
- tuner: A BaseTuner that will be used for tuning.
- fit_kwargs: Args to pass to tuner's run_trial function for fitting the
model , e.g., the training and validation dataset. Required
args depend on the above tuner's implementation.
"""
...
在此函数中,您定义模型和超参数搜索空间,并选择调整的目标和算法。调谐器组件将此模块代码作为输入,调整超参数,并发出最佳结果。
Trainer 可以将调谐器的输出超参数作为输入,并在其用户模块代码中使用它们。管道定义如下所示
...
tuner = Tuner(
module_file=module_file, # Contains `tuner_fn`.
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=20),
eval_args=trainer_pb2.EvalArgs(num_steps=5))
trainer = Trainer(
module_file=module_file, # Contains `run_fn`.
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
schema=schema_gen.outputs['schema'],
# This will be passed to `run_fn`.
hyperparameters=tuner.outputs['best_hyperparameters'],
train_args=trainer_pb2.TrainArgs(num_steps=100),
eval_args=trainer_pb2.EvalArgs(num_steps=5))
...
您可能不希望每次重新训练模型时都调整超参数。一旦您使用调谐器确定了一组良好的超参数,就可以从管道中删除调谐器,并使用 ImporterNode
从以前的训练运行中导入调谐器工件以馈送到 Trainer。
hparams_importer = Importer(
# This can be Tuner's output file or manually edited file. The file contains
# text format of hyperparameters (keras_tuner.HyperParameters.get_config())
source_uri='path/to/best_hyperparameters.txt',
artifact_type=HyperParameters,
).with_id('import_hparams')
trainer = Trainer(
...
# An alternative is directly use the tuned hyperparameters in Trainer's user
# module code and set hyperparameters to None here.
hyperparameters = hparams_importer.outputs['result'])
在 Google Cloud Platform (GCP) 上进行调整
在 Google Cloud Platform (GCP) 上运行时,调谐器组件可以利用两项服务
- AI Platform Vizier(通过 CloudTuner 实现)
- AI Platform Training(作为分布式调整的集群管理器)
AI Platform Vizier 作为超参数调整的后端
AI Platform Vizier 是一种托管服务,它基于 Google Vizier 技术执行黑盒优化。
CloudTuner 是 KerasTuner 的实现,它使用 AI Platform Vizier 服务作为研究后端。由于 CloudTuner 是 keras_tuner.Tuner
的子类,因此它可以作为 tuner_fn
模块中的直接替换,并作为 TFX Tuner 组件的一部分执行。
以下代码片段展示了如何使用 CloudTuner
。请注意,配置 CloudTuner
需要特定于 GCP 的项目,例如 project_id
和 region
。
...
from tensorflow_cloud import CloudTuner
...
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
"""An implementation of tuner_fn that instantiates CloudTuner."""
...
tuner = CloudTuner(
_build_model,
hyperparameters=...,
...
project_id=..., # GCP Project ID
region=..., # GCP Region where Vizier service is run.
)
...
return TuneFnResult(
tuner=tuner,
fit_kwargs={...}
)
在 Cloud AI Platform Training 分布式工作器集群上进行并行调优
作为 Tuner 组件底层实现的 KerasTuner 框架能够并行执行超参数搜索。虽然默认的 Tuner 组件无法并行执行多个搜索工作器,但通过使用 Google Cloud AI Platform 扩展 Tuner 组件,它提供了使用 AI Platform Training 作业作为分布式工作器集群管理器来运行并行调优的能力。 TuneArgs 是提供给此组件的配置。它是默认 Tuner 组件的直接替换。
tuner = google_cloud_ai_platform.Tuner(
... # Same kwargs as the above stock Tuner component.
tune_args=proto.TuneArgs(num_parallel_trials=3), # 3-worker parallel
custom_config={
# Configures Cloud AI Platform-specific configs . For for details, see
# https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#traininginput.
TUNING_ARGS_KEY:
{
'project': ...,
'region': ...,
# Configuration of machines for each master/worker in the flock.
'masterConfig': ...,
'workerConfig': ...,
...
}
})
...
扩展 Tuner 组件的行为和输出与默认 Tuner 组件相同,只是多个超参数搜索在不同的工作器机器上并行执行,因此 num_trials
将更快完成。当搜索算法可以轻松并行化时,这尤其有效,例如 RandomSearch
。但是,如果搜索算法使用先前试验结果的信息,例如 AI Platform Vizier 中实现的 Google Vizier 算法,则过度并行搜索会对搜索的有效性产生负面影响。
链接
更多详细信息可在 Tuner API 参考 中找到。