Overview
Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.
The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:
- Generate training data from the input graph and sample features. Nodes in the graph correspond to samples, and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
- Create a neural network as a base model using the Keras sequential, functional, or subclass API.
- Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
- Train and evaluate the graph Keras model.
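Putting these steps together, the flow looks roughly like the following condensed sketch. Each step is covered in detail later in this tutorial; make_base_model(), train_dataset, and test_dataset are hypothetical placeholders for the model-building and data-loading code shown below.

import neural_structured_learning as nsl

# Step 1: training data with neighbor features is generated by a
# preprocessing step (see preprocess_cora_dataset.py below).

# Step 2: build any Keras model as the base model (hypothetical helper).
base_model = make_base_model()

# Step 3: wrap the base model to add a graph regularization loss term.
graph_reg_config = nsl.configs.make_graph_reg_config(max_neighbors=1)
graph_reg_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)

# Step 4: train and evaluate like any other Keras model.
graph_reg_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
graph_reg_model.fit(train_dataset, epochs=5)
graph_reg_model.evaluate(test_dataset)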
Setup
Install the Neural Structured Learning package.
pip install --quiet neural-structured-learning
Dependencies and imports
import neural_structured_learning as nsl
import tensorflow as tf
# Resets notebook state
tf.keras.backend.clear_session()
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
"GPU is",
"available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
2023-11-16 12:04:49.460421: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-16 12:04:49.460472: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-16 12:04:49.461916: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Version:  2.15.0
Eager mode:  True
GPU is NOT AVAILABLE
2023-11-16 12:04:51.768240: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
The Cora dataset
The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task involved is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.
The graph
The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited paper A. Although this is not necessarily true, in this example, we treat citations as a proxy for similarity, which is usually a commutative property.
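As a concrete illustration of this symmetrization, here is a minimal sketch that reads the directed citation edges and adds their reverses. It assumes each line of cora.cites (downloaded below) contains a whitespace-separated pair of paper IDs; the preprocess_cora_dataset.py script used later performs the same step when "Making all edges bi-directional".

# Minimal sketch: symmetrize the directed citation graph by adding the
# reverse of every edge. Assumes whitespace-separated node ID pairs.
undirected_edges = set()
with open('/tmp/cora/cora.cites') as f:
  for line in f:
    source, target = line.split()
    undirected_edges.add((source, target))
    undirected_edges.add((target, source))  # citation treated as symmetric

print('Number of edges after symmetrization:', len(undirected_edges))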
Features
Each paper in the input effectively contains 2 features:
Words: A dense multi-hot bag-of-words representation of the text of the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not.
Label: A single integer representing the class ID (category) of the paper.
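To make the 'words' feature concrete, the following small sketch builds such a multi-hot vector by hand. The word indices here are made up; in the dataset this feature already comes multi-hot encoded.

import numpy as np

VOCAB_SIZE = 1433  # size of the Cora vocabulary

# Hypothetical vocabulary indices of the words appearing in one paper.
word_indices_in_paper = [3, 17, 256, 1432]

words_feature = np.zeros(VOCAB_SIZE, dtype=np.int64)
words_feature[word_indices_in_paper] = 1  # position i is 1 iff word i occurs

print(words_feature.shape)                 # (1433,)
print(words_feature[3], words_feature[4])  # 1 0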
Download the Cora dataset
wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content
Convert the Cora data to the NSL format
In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL github repository. This script does the following:
- Generate neighbor features using the original node features and the graph.
- Generate train and test data splits containing tf.train.Example instances.
- Persist the resulting train and test data in the TFRecord format.
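The merged training examples written by the script follow a simple feature-naming convention that the rest of this tutorial relies on: 'NL_nbr_<i>_words' holds the i-th neighbor's word feature and 'NL_nbr_<i>_weight' holds the corresponding edge weight. A hypothetical merged example with one neighbor (and a toy 3-word vocabulary, for brevity) might look like the following sketch:

import tensorflow as tf

# A hypothetical merged tf.train.Example with one neighbor; real examples
# use the full 1433-word vocabulary and up to max_nbrs neighbors.
merged_example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'words':
                tf.train.Feature(int64_list=tf.train.Int64List(value=[0, 1, 0])),
            'label':
                tf.train.Feature(int64_list=tf.train.Int64List(value=[2])),
            'NL_nbr_0_words':
                tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 0, 0])),
            'NL_nbr_0_weight':
                tf.train.Feature(float_list=tf.train.FloatList(value=[1.0])),
        }))
print(merged_example)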
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2023-11-16 12:04:52--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]  11.37K  --.-KB/s    in 0s

2023-11-16 12:04:53 (75.6 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

2023-11-16 12:04:53.758687: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-16 12:04:53.758743: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-16 12:04:53.760530: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-16 12:04:55.968449: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.44 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.05 minutes.
Global variables
The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.
### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'
### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
Hyperparameters
We will use an instance of HParams to hold various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:
num_classes: There are a total of 7 different classes.
max_seq_length: This is the size of the vocabulary; all instances in the input have a dense multi-hot bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input, and a value of 0 indicates that it is not.
distance_type: This is the distance metric used to regularize a sample with respect to its neighbors.
graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.
num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.
num_fc_units: The number of units in each fully connected layer of the neural network.
train_epochs: The number of training epochs.
batch_size: Batch size used for training and evaluation.
dropout_rate: Controls the rate of dropout following each fully connected layer.
eval_steps: The number of batches to process before deeming evaluation is complete. If set to None, all instances in the test set are evaluated.
class HParams(object):
"""Hyperparameters used for training."""
def __init__(self):
### dataset parameters
self.num_classes = 7
self.max_seq_length = 1433
### neural graph learning parameters
self.distance_type = nsl.configs.DistanceType.L2
self.graph_regularization_multiplier = 0.1
self.num_neighbors = 1
### model architecture
self.num_fc_units = [50, 50]
### training parameters
self.train_epochs = 100
self.batch_size = 128
self.dropout_rate = 0.5
### eval parameters
self.eval_steps = None # All instances in the test set are evaluated.
HPARAMS = HParams()
Load train and test data
As described earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects -- one for training and one for testing.
In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample, but also the corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for those non-existent neighbor features.
def make_dataset(file_path, training=False):
"""Creates a `tf.data.TFRecordDataset`.
Args:
file_path: Name of the file in the `.tfrecord` format containing
`tf.train.Example` objects.
training: Boolean indicating if we are in training mode.
Returns:
An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
objects.
"""
def parse_example(example_proto):
"""Extracts relevant fields from the `example_proto`.
Args:
example_proto: An instance of `tf.train.Example`.
Returns:
A pair whose first value is a dictionary containing relevant features
and whose second value contains the ground truth label.
"""
# The 'words' feature is a multi-hot, bag-of-words representation of the
# original raw text. A default value is required for examples that don't
# have the feature.
feature_spec = {
'words':
tf.io.FixedLenFeature([HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0,
dtype=tf.int64,
shape=[HPARAMS.max_seq_length])),
'label':
tf.io.FixedLenFeature((), tf.int64, default_value=-1),
}
# We also extract corresponding neighbor features in a similar manner to
# the features above during training.
if training:
for i in range(HPARAMS.num_neighbors):
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
NBR_WEIGHT_SUFFIX)
feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
[HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))
# We assign a default value of 0.0 for the neighbor weight so that
# graph regularization is done on samples based on their exact number
# of neighbors. In other words, non-existent neighbors are discounted.
feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
[1], tf.float32, default_value=tf.constant([0.0]))
features = tf.io.parse_single_example(example_proto, feature_spec)
label = features.pop('label')
return features, label
dataset = tf.data.TFRecordDataset([file_path])
if training:
dataset = dataset.shuffle(10000)
dataset = dataset.map(parse_example)
dataset = dataset.batch(HPARAMS.batch_size)
return dataset
train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
Let's peek into the train dataset to look at its contents.
for feature_batch, label_batch in train_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
print('Batch of neighbor weights:',
tf.reshape(feature_batch[nbr_weight_key], [-1]))
print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 1 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor weights: tf.Tensor( [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32) Batch of labels: tf.Tensor( [2 2 3 6 6 4 3 1 3 4 2 5 4 5 6 4 1 5 1 0 5 6 3 0 4 2 4 4 1 1 1 6 2 2 5 3 3 5 3 2 0 0 1 5 5 0 4 6 1 4 2 0 2 4 4 1 3 2 2 2 1 2 2 5 2 2 4 1 2 6 1 6 3 0 5 2 6 4 3 2 4 0 2 1 2 2 2 2 2 2 1 1 6 3 2 4 1 2 1 0 3 0 0 3 2 6 1 2 2 1 2 2 2 3 2 0 2 3 2 5 3 0 1 1 2 0 2 1], shape=(128,), dtype=int64)
Let's peek into the test dataset to look at its contents.
for feature_batch, label_batch in test_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
print('Batch of labels:', label_batch)
Feature list: ['words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of labels: tf.Tensor( [5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)
Model definition
To demonstrate the use of graph regularization, we build a base model for this problem first. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework -- sequential, functional, and subclass.
Sequential base model
def make_mlp_sequential_model(hparams):
"""Creates a sequential multi-layer perceptron model."""
model = tf.keras.Sequential()
model.add(
tf.keras.layers.InputLayer(
input_shape=(hparams.max_seq_length,), name='words'))
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
model.add(
tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
for num_units in hparams.num_fc_units:
model.add(tf.keras.layers.Dense(num_units, activation='relu'))
# For sequential models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
model.add(tf.keras.layers.Dense(hparams.num_classes))
return model
Functional base model
def make_mlp_functional_model(hparams):
"""Creates a functional API-based multi-layer perceptron model."""
inputs = tf.keras.Input(
shape=(hparams.max_seq_length,), dtype='int64', name='words')
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
cur_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))(
inputs)
for num_units in hparams.num_fc_units:
cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
# For functional models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)
outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)
model = tf.keras.Model(inputs, outputs=outputs)
return model
Subclass base model
def make_mlp_subclass_model(hparams):
"""Creates a multi-layer perceptron subclass model in Keras."""
class MLP(tf.keras.Model):
"""Subclass model defining a multi-layer perceptron."""
def __init__(self):
super(MLP, self).__init__()
# Input is already one-hot encoded in the integer format. We create a
# layer to cast it to floating point format here.
self.cast_to_float_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))
self.dense_layers = [
tf.keras.layers.Dense(num_units, activation='relu')
for num_units in hparams.num_fc_units
]
self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
self.output_layer = tf.keras.layers.Dense(hparams.num_classes)
def call(self, inputs, training=False):
cur_layer = self.cast_to_float_layer(inputs['words'])
for dense_layer in self.dense_layers:
cur_layer = dense_layer(cur_layer)
cur_layer = self.dropout_layer(cur_layer, training=training)
outputs = self.output_layer(cur_layer)
return outputs
return MLP()
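Note that a subclass model computes its shapes lazily, so it must be built (for example, by calling it on a batch of inputs) before model.summary() can be generated. A quick sketch of this, reusing the builder defined above:

# A subclass model must be built before its summary can be printed.
subclass_model = make_mlp_subclass_model(HPARAMS)
subclass_model({'words': tf.zeros([1, HPARAMS.max_seq_length], dtype=tf.int64)})
subclass_model.summary()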
Create the base model
# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 words (InputLayer)          [(None, 1433)]            0
 lambda (Lambda)             (None, 1433)              0
 dense (Dense)               (None, 50)                71700
 dropout (Dropout)           (None, 50)                0
 dense_1 (Dense)             (None, 50)                2550
 dropout_1 (Dropout)         (None, 50)                0
 dense_2 (Dense)             (None, 7)                 357
=================================================================
Total params: 74607 (291.43 KB)
Trainable params: 74607 (291.43 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Train the base MLP model
# Compile and train the base MLP model
base_model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/functional.py:642: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model. inputs = self._flatten_to_reference_inputs(inputs) 17/17 [==============================] - 1s 6ms/step - loss: 1.9105 - accuracy: 0.2260 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8280 - accuracy: 0.3044 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7240 - accuracy: 0.3299 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5969 - accuracy: 0.3745 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4765 - accuracy: 0.4492 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3235 - accuracy: 0.5276 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1913 - accuracy: 0.5889 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0604 - accuracy: 0.6432 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9628 - accuracy: 0.6821 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8601 - accuracy: 0.7234 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7914 - accuracy: 0.7480 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7230 - accuracy: 0.7633 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6783 - accuracy: 0.7791 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6019 - accuracy: 0.8070 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5587 - accuracy: 0.8367 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5295 - accuracy: 0.8450 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4789 - accuracy: 0.8599 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4474 - accuracy: 0.8650 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4148 - accuracy: 0.8701 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3812 - accuracy: 0.8896 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3656 - accuracy: 0.8863 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3544 - accuracy: 0.8923 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3050 - accuracy: 0.9165 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2858 - accuracy: 0.9216 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2821 - accuracy: 0.9234 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2543 - accuracy: 0.9276 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2477 - accuracy: 0.9285 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2413 - accuracy: 0.9295 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2153 - accuracy: 0.9415 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2241 - accuracy: 0.9290 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2118 - accuracy: 0.9374 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2041 - accuracy: 0.9471 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 
0.1951 - accuracy: 0.9392 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1841 - accuracy: 0.9443 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1783 - accuracy: 0.9522 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1742 - accuracy: 0.9485 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1705 - accuracy: 0.9541 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1507 - accuracy: 0.9592 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1513 - accuracy: 0.9555 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1378 - accuracy: 0.9652 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1471 - accuracy: 0.9587 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1309 - accuracy: 0.9661 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1288 - accuracy: 0.9596 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1327 - accuracy: 0.9629 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1170 - accuracy: 0.9675 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1198 - accuracy: 0.9666 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1183 - accuracy: 0.9680 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1025 - accuracy: 0.9740 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0981 - accuracy: 0.9754 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9708 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0874 - accuracy: 0.9796 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1027 - accuracy: 0.9735 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0993 - accuracy: 0.9740 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0934 - accuracy: 0.9759 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0932 - accuracy: 0.9759 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0787 - accuracy: 0.9810 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0890 - accuracy: 0.9754 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0918 - accuracy: 0.9749 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0908 - accuracy: 0.9717 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0825 - accuracy: 0.9777 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0926 - accuracy: 0.9684 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0702 - accuracy: 0.9800 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0720 - accuracy: 0.9842 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0792 - accuracy: 0.9773 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0760 - accuracy: 0.9782 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0736 - accuracy: 0.9800 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0838 - accuracy: 0.9773 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0639 - accuracy: 0.9824 Epoch 69/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.0742 - accuracy: 0.9805 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0798 - accuracy: 0.9782 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0694 - accuracy: 0.9805 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0635 - accuracy: 0.9833 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0587 - accuracy: 0.9824 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0689 - accuracy: 0.9828 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0628 - accuracy: 0.9828 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0570 - accuracy: 0.9842 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0632 - accuracy: 0.9824 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0673 - accuracy: 0.9782 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0573 - accuracy: 0.9828 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0640 - accuracy: 0.9824 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0610 - accuracy: 0.9810 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0553 - accuracy: 0.9861 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0482 - accuracy: 0.9879 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0548 - accuracy: 0.9842 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0537 - accuracy: 0.9865 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0540 - accuracy: 0.9828 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0528 - accuracy: 0.9838 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0505 - accuracy: 0.9865 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0473 - accuracy: 0.9833 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0604 - accuracy: 0.9810 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0469 - accuracy: 0.9879 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0554 - accuracy: 0.9810 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0427 - accuracy: 0.9875 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0581 - accuracy: 0.9824 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0488 - accuracy: 0.9842 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0466 - accuracy: 0.9875 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0465 - accuracy: 0.9875 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0411 - accuracy: 0.9879 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0539 - accuracy: 0.9852 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0451 - accuracy: 0.9870 <keras.src.callbacks.History at 0x7f459c2e9e50>
Evaluate the base MLP model
# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
"""Prints evaluation metrics.
Args:
model_desc: A description of the model.
eval_metrics: A dictionary mapping metric names to corresponding values. It
must contain the loss and accuracy metrics.
"""
print('\n')
print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
if 'graph_loss' in eval_metrics:
print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
zip(base_model.metrics_names,
base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.4164 - accuracy: 0.7758

Eval accuracy for  Base MLP model :  0.775768518447876
Eval loss for  Base MLP model :  1.4164185523986816
Train the MLP model with graph regularization
Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.
To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because base_model has already been trained for a few iterations, and reusing that trained model to create a graph-regularized model would not be a fair comparison for base_model.
# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
max_neighbors=HPARAMS.num_neighbors,
multiplier=HPARAMS.graph_regularization_multiplier,
distance_type=HPARAMS.distance_type,
sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
graph_reg_config)
graph_reg_model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 17/17 [==============================] - 2s 7ms/step - loss: 1.9586 - accuracy: 0.2107 - scaled_graph_loss: 0.0319 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8903 - accuracy: 0.2942 - scaled_graph_loss: 0.0282 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8290 - accuracy: 0.3262 - scaled_graph_loss: 0.0411 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7762 - accuracy: 0.3248 - scaled_graph_loss: 0.0604 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7334 - accuracy: 0.3568 - scaled_graph_loss: 0.0792 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6859 - accuracy: 0.3735 - scaled_graph_loss: 0.0920 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6506 - accuracy: 0.3935 - scaled_graph_loss: 0.1086 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6028 - accuracy: 0.4520 - scaled_graph_loss: 0.1249 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5690 - accuracy: 0.5012 - scaled_graph_loss: 0.1386 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5332 - accuracy: 0.5420 - scaled_graph_loss: 0.1577 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4792 - accuracy: 0.5842 - scaled_graph_loss: 0.1642 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4438 - accuracy: 0.6306 - scaled_graph_loss: 0.1909 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4155 - accuracy: 0.6617 - scaled_graph_loss: 0.2009 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3596 - accuracy: 0.6896 - scaled_graph_loss: 0.1964 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3462 - accuracy: 0.7077 - scaled_graph_loss: 0.2294 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3151 - accuracy: 0.7295 - scaled_graph_loss: 0.2312 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2848 - accuracy: 0.7555 - scaled_graph_loss: 0.2319 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2643 - accuracy: 0.7759 - scaled_graph_loss: 0.2469 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2434 - accuracy: 0.7921 - scaled_graph_loss: 0.2544 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2005 - accuracy: 0.8093 - scaled_graph_loss: 0.2473 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2007 - accuracy: 0.8070 - scaled_graph_loss: 0.2688 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1876 - accuracy: 0.8135 - scaled_graph_loss: 0.2708 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1729 - accuracy: 0.8274 - scaled_graph_loss: 0.2662 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8376 - scaled_graph_loss: 0.2707 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1228 - accuracy: 0.8538 - scaled_graph_loss: 0.2677 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1166 - accuracy: 0.8603 - scaled_graph_loss: 0.2785 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1176 - accuracy: 0.8473 - scaled_graph_loss: 0.2807 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1085 - accuracy: 0.8473 - 
scaled_graph_loss: 0.2649 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0751 - accuracy: 0.8691 - scaled_graph_loss: 0.2858 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0851 - accuracy: 0.8696 - scaled_graph_loss: 0.2996 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0932 - accuracy: 0.8770 - scaled_graph_loss: 0.2892 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0619 - accuracy: 0.8821 - scaled_graph_loss: 0.2880 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0531 - accuracy: 0.8886 - scaled_graph_loss: 0.2847 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0558 - accuracy: 0.8863 - scaled_graph_loss: 0.2962 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0375 - accuracy: 0.8891 - scaled_graph_loss: 0.2780 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0310 - accuracy: 0.8858 - scaled_graph_loss: 0.2932 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0269 - accuracy: 0.8872 - scaled_graph_loss: 0.2916 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0273 - accuracy: 0.8928 - scaled_graph_loss: 0.2948 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9935 - accuracy: 0.9123 - scaled_graph_loss: 0.2910 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0083 - accuracy: 0.9104 - scaled_graph_loss: 0.2951 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0196 - accuracy: 0.8951 - scaled_graph_loss: 0.2982 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9941 - accuracy: 0.9007 - scaled_graph_loss: 0.2898 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0069 - accuracy: 0.9012 - scaled_graph_loss: 0.3076 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9816 - accuracy: 0.9049 - scaled_graph_loss: 0.2930 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9910 - accuracy: 0.9104 - scaled_graph_loss: 0.2954 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9949 - accuracy: 0.9026 - scaled_graph_loss: 0.3111 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9715 - accuracy: 0.9114 - scaled_graph_loss: 0.2830 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9796 - accuracy: 0.9067 - scaled_graph_loss: 0.2970 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9570 - accuracy: 0.9114 - scaled_graph_loss: 0.2936 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9691 - accuracy: 0.9049 - scaled_graph_loss: 0.2940 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9803 - accuracy: 0.9114 - scaled_graph_loss: 0.3083 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9612 - accuracy: 0.9128 - scaled_graph_loss: 0.2860 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9627 - accuracy: 0.9216 - scaled_graph_loss: 0.3077 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9516 - accuracy: 0.9151 - scaled_graph_loss: 0.2906 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9431 - accuracy: 0.9197 - scaled_graph_loss: 0.2967 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - 
loss: 0.9622 - accuracy: 0.9132 - scaled_graph_loss: 0.3053 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9410 - accuracy: 0.9188 - scaled_graph_loss: 0.2830 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9531 - accuracy: 0.9230 - scaled_graph_loss: 0.3049 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9309 - accuracy: 0.9193 - scaled_graph_loss: 0.3009 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9300 - accuracy: 0.9248 - scaled_graph_loss: 0.2988 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9173 - accuracy: 0.9244 - scaled_graph_loss: 0.2884 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9228 - accuracy: 0.9248 - scaled_graph_loss: 0.2960 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9394 - accuracy: 0.9174 - scaled_graph_loss: 0.3102 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9182 - accuracy: 0.9174 - scaled_graph_loss: 0.2899 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9276 - accuracy: 0.9253 - scaled_graph_loss: 0.2996 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9229 - accuracy: 0.9244 - scaled_graph_loss: 0.2912 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9325 - accuracy: 0.9142 - scaled_graph_loss: 0.3088 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9091 - accuracy: 0.9216 - scaled_graph_loss: 0.2883 Epoch 69/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8987 - accuracy: 0.9267 - scaled_graph_loss: 0.2924 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9188 - accuracy: 0.9216 - scaled_graph_loss: 0.2970 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9003 - accuracy: 0.9299 - scaled_graph_loss: 0.2962 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9086 - accuracy: 0.9206 - scaled_graph_loss: 0.2944 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9047 - accuracy: 0.9304 - scaled_graph_loss: 0.3174 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9214 - accuracy: 0.9202 - scaled_graph_loss: 0.2923 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9081 - accuracy: 0.9276 - scaled_graph_loss: 0.3020 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9043 - accuracy: 0.9220 - scaled_graph_loss: 0.2892 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9022 - accuracy: 0.9253 - scaled_graph_loss: 0.2998 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8871 - accuracy: 0.9332 - scaled_graph_loss: 0.2979 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8863 - accuracy: 0.9295 - scaled_graph_loss: 0.3021 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8893 - accuracy: 0.9225 - scaled_graph_loss: 0.2928 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8850 - accuracy: 0.9258 - scaled_graph_loss: 0.2997 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9013 - accuracy: 0.9165 - scaled_graph_loss: 0.2961 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8739 - accuracy: 0.9253 - scaled_graph_loss: 0.2886 Epoch 84/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.8840 - accuracy: 0.9318 - scaled_graph_loss: 0.3040 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8628 - accuracy: 0.9378 - scaled_graph_loss: 0.2886 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8745 - accuracy: 0.9313 - scaled_graph_loss: 0.3013 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8678 - accuracy: 0.9327 - scaled_graph_loss: 0.2980 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8614 - accuracy: 0.9397 - scaled_graph_loss: 0.2947 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8589 - accuracy: 0.9327 - scaled_graph_loss: 0.2957 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8688 - accuracy: 0.9346 - scaled_graph_loss: 0.2996 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8661 - accuracy: 0.9216 - scaled_graph_loss: 0.2881 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8828 - accuracy: 0.9318 - scaled_graph_loss: 0.3019 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8701 - accuracy: 0.9374 - scaled_graph_loss: 0.3051 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8572 - accuracy: 0.9383 - scaled_graph_loss: 0.2998 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9327 - scaled_graph_loss: 0.2999 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8685 - accuracy: 0.9336 - scaled_graph_loss: 0.3013 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8710 - accuracy: 0.9378 - scaled_graph_loss: 0.3023 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8746 - accuracy: 0.9327 - scaled_graph_loss: 0.2956 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8642 - accuracy: 0.9341 - scaled_graph_loss: 0.2984 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8638 - accuracy: 0.9318 - scaled_graph_loss: 0.2965 <keras.src.callbacks.History at 0x7f445862f130>
Evaluate the MLP model with graph regularization
eval_results = dict(
zip(graph_reg_model.metrics_names,
graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8791 - accuracy: 0.7993

Eval accuracy for  MLP + graph regularization :  0.7992766499519348
Eval loss for  MLP + graph regularization :  0.8790676593780518
The graph-regularized model's accuracy is about 2-3% higher than that of the base model (base_model).
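To quantify that improvement explicitly, one option is to re-evaluate both models side by side, as in this short sketch (it assumes base_model and graph_reg_model are still in scope from the cells above):

# Re-evaluate both models on the test set and report the accuracy delta.
base_metrics = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps, verbose=0)))
graph_metrics = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(
            test_dataset, steps=HPARAMS.eval_steps, verbose=0)))
print('Accuracy delta: {:.2%}'.format(
    graph_metrics['accuracy'] - base_metrics['accuracy']))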
Conclusion
We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs based on sample embeddings before training a neural network with graph regularization, an approach that is useful if the input does not contain an explicit graph.
We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.
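As a starting point, here is a hedged sketch of one such experiment: a sweep over the graph regularization multiplier, reusing the datasets, hyperparameters, and model-building function defined above (training output is suppressed). The multiplier values are illustrative, not tuned.

# Sketch: sweep the graph regularization multiplier and compare test accuracy.
for multiplier in [0.0, 0.1, 0.5, 1.0]:
  config = nsl.configs.make_graph_reg_config(
      max_neighbors=HPARAMS.num_neighbors,
      multiplier=multiplier,
      distance_type=HPARAMS.distance_type,
      sum_over_axis=-1)
  model = nsl.keras.GraphRegularization(
      make_mlp_functional_model(HPARAMS), config)
  model.compile(
      optimizer='adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])
  model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=0)
  metrics = dict(
      zip(model.metrics_names,
          model.evaluate(test_dataset, steps=HPARAMS.eval_steps, verbose=0)))
  print('multiplier={}: test accuracy={:.4f}'.format(
      multiplier, metrics['accuracy']))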