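Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, making use of both labeled and unlabeled data.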
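The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:
- Generate training data from the input graph and sample features. Nodes in the graph correspond to samples, and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
- Create a neural network as a base model using the Keras sequential, functional, or subclass API.
- Wrap the base model with the GraphRegularization wrapper class provided by the NSL framework to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
- Train and evaluate the graph Keras model.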
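To make the regularization term concrete, here is a minimal pure-Python sketch of a graph-regularized objective: a supervised loss plus a multiplier times the weighted distance between a sample's output and each neighbor's output. The function names, the choice of squared L2 distance, and the toy numbers are illustrative assumptions, not NSL's implementation.

```python
def squared_l2(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def graph_regularized_loss(supervised_loss, sample_output, neighbor_outputs,
                           neighbor_weights, multiplier):
    """Adds a weighted neighbor-distance penalty to a supervised loss.

    Neighbors with weight 0.0 (for example, padding for missing neighbors)
    contribute nothing to the penalty.
    """
    graph_loss = sum(
        weight * squared_l2(sample_output, neighbor)
        for neighbor, weight in zip(neighbor_outputs, neighbor_weights))
    return supervised_loss + multiplier * graph_loss


# Toy example: one real neighbor (weight 1.0) and one padded neighbor
# (weight 0.0). All values are hypothetical.
total = graph_regularized_loss(
    supervised_loss=0.5,
    sample_output=[1.0, 0.0],
    neighbor_outputs=[[0.0, 0.0], [5.0, 5.0]],
    neighbor_weights=[1.0, 0.0],
    multiplier=0.1)
print(total)
```

The penalty pulls the outputs of connected samples toward each other, while the multiplier controls how strongly it competes with the supervised loss.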
pip install --quiet neural-structured-learning
import neural_structured_learning as nsl
import tensorflow as tf
# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version: 2.15.0
Eager mode: True
GPU is NOT AVAILABLE
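Cora dataset
The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task involved is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.
The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited paper A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a symmetric property.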
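As a small illustration of the undirected view, the sketch below adds the reverse of every citation edge. The paper IDs and the helper name `make_undirected` are hypothetical and only serve this example.

```python
def make_undirected(directed_edges):
    """Returns an edge set containing every edge in both directions."""
    undirected = set()
    for source, target in directed_edges:
        undirected.add((source, target))
        undirected.add((target, source))
    return undirected


# Hypothetical citations: A cites B, and B cites C.
citations = [("A", "B"), ("B", "C")]
edges = make_undirected(citations)
print(sorted(edges))
```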
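Each paper in the input effectively contains 2 features:
- Words: A dense, multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not.
- Label: A single integer representing the class ID (category) of the paper.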
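The multi-hot representation can be sketched in a few lines. The word indices below are made up for illustration; only the vocabulary size of 1433 comes from the dataset description.

```python
def multi_hot(present_word_ids, vocab_size):
    """Builds a 0/1 vector with a 1 at each index in `present_word_ids`."""
    vector = [0] * vocab_size
    for word_id in present_word_ids:
        vector[word_id] = 1
    return vector


VOCAB_SIZE = 1433  # Number of unique words in the Cora vocabulary.

# Hypothetical paper containing vocabulary words 0, 7, and 1432.
words = multi_hot([0, 7, 1432], VOCAB_SIZE)
print(len(words), sum(words))
```

Word order and word counts are discarded; the vector records only which vocabulary words occur in the paper.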
Download the Cora dataset
wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content
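Convert the Cora data to the NSL format
In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL GitHub repository. This script does the following:
- Generates neighbor features using the original node features and the graph.
- Generates train and test data splits containing tf.train.Example instances.
- Persists the resulting train and test data in the TFRecord format.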
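The neighbor-feature step can be pictured with plain dictionaries. This is a simplified sketch, not the script's code: the real script works with tf.train.Example records, and the helper name, the toy features, and the rule of keeping the first neighbors in sorted order are assumptions made for illustration. The `NL_nbr_<i>_words` and `NL_nbr_<i>_weight` key names follow the feature names that show up when the training data is loaded.

```python
def merge_neighbor_features(node_features, edges, max_nbrs):
    """Attaches up to `max_nbrs` neighbor feature vectors to each node."""
    neighbors = {node: [] for node in node_features}
    for source, target in edges:
        neighbors[source].append(target)

    merged = {}
    for node, words in node_features.items():
        example = {"words": words}
        # Assumption: keep the first `max_nbrs` neighbors in sorted order.
        for i, nbr in enumerate(sorted(neighbors[node])[:max_nbrs]):
            example["NL_nbr_{}_words".format(i)] = node_features[nbr]
            example["NL_nbr_{}_weight".format(i)] = 1.0
        merged[node] = example
    return merged


# Hypothetical 3-word vocabulary and a bi-directional toy graph.
node_features = {"p1": [1, 0, 1], "p2": [0, 1, 1], "p3": [1, 1, 0]}
edges = [("p1", "p2"), ("p2", "p1"), ("p1", "p3"), ("p3", "p1")]
merged = merge_neighbor_features(node_features, edges, max_nbrs=1)
print(merged["p1"])
```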
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
!python preprocess_cora_dataset.py \
  --input_cora_content=/tmp/cora/cora.content \
  --input_cora_graph=/tmp/cora/cora.cites \
  --max_nbrs=5 \
  --output_train_data=/tmp/cora/train_merged_examples.tfr \
  --output_test_data=/tmp/cora/test_examples.tfr
--2023-11-16 12:04:52-- https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’
2023-11-16 12:04:53 (75.6 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]
Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.44 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.05 minutes.
The file paths for the train and test data are based on the command-line flag values used to invoke the 'preprocess_cora_dataset.py' script above.
### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'
### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
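We will use an instance of HParams to hold the hyperparameters and constants used for training and evaluation. Some of them are described below:
- num_classes: There are 7 different classes in total.
- max_seq_length: The size of the vocabulary. All instances in the input have a dense, multi-hot bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input, and a value of 0 indicates that it is not.
- dropout_rate: Controls the rate of dropout following each fully connected layer.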
class HParams(object):
  """Hyperparameters used for training."""

  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.


HPARAMS = HParams()
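As described earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects, one for training and one for testing.
In the input layer of our model, we will extract not just the 'words' and the 'label' features from each sample, but also the corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for those non-existent neighbor features.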
def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 1 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor weights: tf.Tensor( [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32) Batch of labels: tf.Tensor( [2 2 3 6 6 4 3 1 3 4 2 5 4 5 6 4 1 5 1 0 5 6 3 0 4 2 4 4 1 1 1 6 2 2 5 3 3 5 3 2 0 0 1 5 5 0 4 6 1 4 2 0 2 4 4 1 3 2 2 2 1 2 2 5 2 2 4 1 2 6 1 6 3 0 5 2 6 4 3 2 4 0 2 1 2 2 2 2 2 2 1 1 6 3 2 4 1 2 1 0 3 0 0 3 2 6 1 2 2 1 2 2 2 3 2 0 2 3 2 5 3 0 1 1 2 0 2 1], shape=(128,), dtype=int64)
for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of labels: tf.Tensor( [5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)
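As a quick sanity check on batching, the arithmetic below estimates how many batches each dataset yields. The node and training-example counts come from the preprocessing output, and the batch size from HPARAMS. Treating every node outside the training split as a test example is an assumption of this sketch.

```python
import math

TOTAL_NODES = 2708     # "Total graph nodes" reported by the preprocessing script.
TRAIN_EXAMPLES = 2155  # Merged training examples written by the script.
BATCH_SIZE = 128       # HPARAMS.batch_size

# Assumption: all remaining nodes end up in the test split.
test_examples = TOTAL_NODES - TRAIN_EXAMPLES

# The last batch may be smaller than BATCH_SIZE, hence the ceiling.
train_steps = math.ceil(TRAIN_EXAMPLES / BATCH_SIZE)
test_steps = math.ceil(test_examples / BATCH_SIZE)
print(train_steps, test_steps)  # → 17 5
```

These values agree with the `17/17` and `5/5` step counters that Keras prints in the training and evaluation logs.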
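To demonstrate the use of graph regularization, we first build a base model for this problem. We will use a simple feed-forward neural network with 2 hidden layers and dropout. We use tf.Keras to illustrate the creation of the base model with each of the three model types defined below: sequential, functional, and subclass.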
def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes))
  return model
def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model
def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(hparams.num_classes)

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()
# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 words (InputLayer)          [(None, 1433)]            0
 lambda (Lambda)             (None, 1433)              0
 dense (Dense)               (None, 50)                71700
 dropout (Dropout)           (None, 50)                0
 dense_1 (Dense)             (None, 50)                2550
 dropout_1 (Dropout)         (None, 50)                0
 dense_2 (Dense)             (None, 7)                 357
=================================================================
Total params: 74607 (291.43 KB)
Trainable params: 74607 (291.43 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
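The parameter counts in the summary can be checked by hand: a dense layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases. The helper name below is ours, and the layer sizes are read off the summary.

```python
def dense_params(num_inputs, num_units):
    """Number of trainable parameters in a dense layer: weights plus biases."""
    return num_inputs * num_units + num_units


# Input size, two hidden layers, and the output layer, as in the summary.
layer_sizes = [1433, 50, 50, 7]
per_layer = [
    dense_params(n_in, n_out)
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
]
total_params = sum(per_layer)
print(per_layer, total_params)  # → [71700, 2550, 357] 74607
```

The Lambda and Dropout layers contribute no parameters, so the three dense layers account for the full total.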
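Train the base MLP model
# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)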
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/functional.py:642: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model. inputs = self._flatten_to_reference_inputs(inputs) 17/17 [==============================] - 1s 6ms/step - loss: 1.9105 - accuracy: 0.2260 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8280 - accuracy: 0.3044 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7240 - accuracy: 0.3299 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5969 - accuracy: 0.3745 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4765 - accuracy: 0.4492 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3235 - accuracy: 0.5276 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1913 - accuracy: 0.5889 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0604 - accuracy: 0.6432 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9628 - accuracy: 0.6821 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8601 - accuracy: 0.7234 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7914 - accuracy: 0.7480 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7230 - accuracy: 0.7633 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6783 - accuracy: 0.7791 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6019 - accuracy: 0.8070 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5587 - accuracy: 0.8367 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5295 - accuracy: 0.8450 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4789 - accuracy: 0.8599 Epoch 18/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.4474 - accuracy: 0.8650 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4148 - accuracy: 0.8701 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3812 - accuracy: 0.8896 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3656 - accuracy: 0.8863 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3544 - accuracy: 0.8923 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3050 - accuracy: 0.9165 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2858 - accuracy: 0.9216 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2821 - accuracy: 0.9234 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2543 - accuracy: 0.9276 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2477 - accuracy: 0.9285 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2413 - accuracy: 0.9295 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2153 - accuracy: 0.9415 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2241 - accuracy: 0.9290 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2118 - accuracy: 0.9374 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2041 - accuracy: 0.9471 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1951 - accuracy: 0.9392 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1841 - accuracy: 0.9443 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1783 - accuracy: 0.9522 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1742 - accuracy: 0.9485 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1705 - accuracy: 0.9541 Epoch 38/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.1507 - accuracy: 0.9592 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1513 - accuracy: 0.9555 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1378 - accuracy: 0.9652 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1471 - accuracy: 0.9587 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1309 - accuracy: 0.9661 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1288 - accuracy: 0.9596 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1327 - accuracy: 0.9629 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1170 - accuracy: 0.9675 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1198 - accuracy: 0.9666 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1183 - accuracy: 0.9680 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1025 - accuracy: 0.9740 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0981 - accuracy: 0.9754 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9708 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0874 - accuracy: 0.9796 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1027 - accuracy: 0.9735 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0993 - accuracy: 0.9740 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0934 - accuracy: 0.9759 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0932 - accuracy: 0.9759 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0787 - accuracy: 0.9810 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0890 - accuracy: 0.9754 Epoch 58/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.0918 - accuracy: 0.9749 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0908 - accuracy: 0.9717 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0825 - accuracy: 0.9777 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0926 - accuracy: 0.9684 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0702 - accuracy: 0.9800 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0720 - accuracy: 0.9842 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0792 - accuracy: 0.9773 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0760 - accuracy: 0.9782 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0736 - accuracy: 0.9800 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0838 - accuracy: 0.9773 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0639 - accuracy: 0.9824 Epoch 69/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0742 - accuracy: 0.9805 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0798 - accuracy: 0.9782 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0694 - accuracy: 0.9805 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0635 - accuracy: 0.9833 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0587 - accuracy: 0.9824 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0689 - accuracy: 0.9828 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0628 - accuracy: 0.9828 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0570 - accuracy: 0.9842 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0632 - accuracy: 0.9824 Epoch 78/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.0673 - accuracy: 0.9782 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0573 - accuracy: 0.9828 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0640 - accuracy: 0.9824 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0610 - accuracy: 0.9810 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0553 - accuracy: 0.9861 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0482 - accuracy: 0.9879 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0548 - accuracy: 0.9842 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0537 - accuracy: 0.9865 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0540 - accuracy: 0.9828 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0528 - accuracy: 0.9838 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0505 - accuracy: 0.9865 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0473 - accuracy: 0.9833 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0604 - accuracy: 0.9810 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0469 - accuracy: 0.9879 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0554 - accuracy: 0.9810 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0427 - accuracy: 0.9875 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0581 - accuracy: 0.9824 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0488 - accuracy: 0.9842 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0466 - accuracy: 0.9875 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0465 - accuracy: 0.9875 Epoch 98/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.0411 - accuracy: 0.9879 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0539 - accuracy: 0.9852 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0451 - accuracy: 0.9870 <keras.src.callbacks.History at 0x7f459c2e9e50>
Evaluate the base MLP model
# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])


eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.4164 - accuracy: 0.7758
Eval accuracy for Base MLP model : 0.775768518447876
Eval loss for Base MLP model : 1.4164185523986816
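To help read the loss and accuracy numbers, here is a minimal pure-Python sketch of the two metrics. It assumes the loss is softmax cross-entropy computed from logits, a common choice for a multi-class model with integer labels whose final Dense layer has no activation. The helper names and toy inputs are illustrative.

```python
import math


def softmax(logits):
    """Converts logits to probabilities, shifting by the max for stability."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(logits, label):
    """Negative log-probability assigned to the integer class `label`."""
    return -math.log(softmax(logits)[label])


def accuracy(batch_logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = sum(
        1 for logits, label in zip(batch_logits, labels)
        if logits.index(max(logits)) == label)
    return correct / len(labels)


# A model that is maximally unsure across 7 classes has loss ln(7).
uniform_loss = sparse_categorical_crossentropy([0.0] * 7, 3)
toy_accuracy = accuracy([[2.0, 1.0], [0.1, 0.9]], [0, 0])
print(uniform_loss, toy_accuracy)
```

Under this assumption, ln(7) ≈ 1.946 is the loss of an uninformed 7-class model, which is close to the first-epoch training loss of 1.9105 in the log.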
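Train the MLP model with graph regularization
Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model whose loss includes graph regularization.
To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because base_model has already been trained, and reusing this trained model to create a graph-regularized model would not be a fair comparison for base_model.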
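# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)

# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)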
Epoch 1/100 17/17 [==============================] - 2s 7ms/step - loss: 1.9586 - accuracy: 0.2107 - scaled_graph_loss: 0.0319 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8903 - accuracy: 0.2942 - scaled_graph_loss: 0.0282 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8290 - accuracy: 0.3262 - scaled_graph_loss: 0.0411 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7762 - accuracy: 0.3248 - scaled_graph_loss: 0.0604 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7334 - accuracy: 0.3568 - scaled_graph_loss: 0.0792 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6859 - accuracy: 0.3735 - scaled_graph_loss: 0.0920 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6506 - accuracy: 0.3935 - scaled_graph_loss: 0.1086 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6028 - accuracy: 0.4520 - scaled_graph_loss: 0.1249 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5690 - accuracy: 0.5012 - scaled_graph_loss: 0.1386 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5332 - accuracy: 0.5420 - scaled_graph_loss: 0.1577 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4792 - accuracy: 0.5842 - scaled_graph_loss: 0.1642 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4438 - accuracy: 0.6306 - scaled_graph_loss: 0.1909 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4155 - accuracy: 0.6617 - scaled_graph_loss: 0.2009 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3596 - accuracy: 0.6896 - scaled_graph_loss: 0.1964 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3462 - accuracy: 0.7077 - scaled_graph_loss: 0.2294 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3151 - 
accuracy: 0.7295 - scaled_graph_loss: 0.2312 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2848 - accuracy: 0.7555 - scaled_graph_loss: 0.2319 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2643 - accuracy: 0.7759 - scaled_graph_loss: 0.2469 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2434 - accuracy: 0.7921 - scaled_graph_loss: 0.2544 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2005 - accuracy: 0.8093 - scaled_graph_loss: 0.2473 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2007 - accuracy: 0.8070 - scaled_graph_loss: 0.2688 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1876 - accuracy: 0.8135 - scaled_graph_loss: 0.2708 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1729 - accuracy: 0.8274 - scaled_graph_loss: 0.2662 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8376 - scaled_graph_loss: 0.2707 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1228 - accuracy: 0.8538 - scaled_graph_loss: 0.2677 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1166 - accuracy: 0.8603 - scaled_graph_loss: 0.2785 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1176 - accuracy: 0.8473 - scaled_graph_loss: 0.2807 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1085 - accuracy: 0.8473 - scaled_graph_loss: 0.2649 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0751 - accuracy: 0.8691 - scaled_graph_loss: 0.2858 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0851 - accuracy: 0.8696 - scaled_graph_loss: 0.2996 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0932 - accuracy: 0.8770 - scaled_graph_loss: 0.2892 Epoch 32/100 17/17 
[==============================] - 0s 3ms/step - loss: 1.0619 - accuracy: 0.8821 - scaled_graph_loss: 0.2880 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0531 - accuracy: 0.8886 - scaled_graph_loss: 0.2847 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0558 - accuracy: 0.8863 - scaled_graph_loss: 0.2962 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0375 - accuracy: 0.8891 - scaled_graph_loss: 0.2780 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0310 - accuracy: 0.8858 - scaled_graph_loss: 0.2932 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0269 - accuracy: 0.8872 - scaled_graph_loss: 0.2916 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0273 - accuracy: 0.8928 - scaled_graph_loss: 0.2948 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9935 - accuracy: 0.9123 - scaled_graph_loss: 0.2910 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0083 - accuracy: 0.9104 - scaled_graph_loss: 0.2951 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0196 - accuracy: 0.8951 - scaled_graph_loss: 0.2982 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9941 - accuracy: 0.9007 - scaled_graph_loss: 0.2898 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0069 - accuracy: 0.9012 - scaled_graph_loss: 0.3076 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9816 - accuracy: 0.9049 - scaled_graph_loss: 0.2930 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9910 - accuracy: 0.9104 - scaled_graph_loss: 0.2954 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9949 - accuracy: 0.9026 - scaled_graph_loss: 0.3111 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9715 - accuracy: 
0.9114 - scaled_graph_loss: 0.2830 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9796 - accuracy: 0.9067 - scaled_graph_loss: 0.2970 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9570 - accuracy: 0.9114 - scaled_graph_loss: 0.2936 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9691 - accuracy: 0.9049 - scaled_graph_loss: 0.2940 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9803 - accuracy: 0.9114 - scaled_graph_loss: 0.3083 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9612 - accuracy: 0.9128 - scaled_graph_loss: 0.2860 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9627 - accuracy: 0.9216 - scaled_graph_loss: 0.3077 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9516 - accuracy: 0.9151 - scaled_graph_loss: 0.2906 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9431 - accuracy: 0.9197 - scaled_graph_loss: 0.2967 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9622 - accuracy: 0.9132 - scaled_graph_loss: 0.3053 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9410 - accuracy: 0.9188 - scaled_graph_loss: 0.2830 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9531 - accuracy: 0.9230 - scaled_graph_loss: 0.3049 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9309 - accuracy: 0.9193 - scaled_graph_loss: 0.3009 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9300 - accuracy: 0.9248 - scaled_graph_loss: 0.2988 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9173 - accuracy: 0.9244 - scaled_graph_loss: 0.2884 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9228 - accuracy: 0.9248 - scaled_graph_loss: 0.2960 Epoch 63/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.9394 - accuracy: 0.9174 - scaled_graph_loss: 0.3102 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9182 - accuracy: 0.9174 - scaled_graph_loss: 0.2899 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9276 - accuracy: 0.9253 - scaled_graph_loss: 0.2996 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9229 - accuracy: 0.9244 - scaled_graph_loss: 0.2912 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9325 - accuracy: 0.9142 - scaled_graph_loss: 0.3088 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9091 - accuracy: 0.9216 - scaled_graph_loss: 0.2883 Epoch 69/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8987 - accuracy: 0.9267 - scaled_graph_loss: 0.2924 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9188 - accuracy: 0.9216 - scaled_graph_loss: 0.2970 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9003 - accuracy: 0.9299 - scaled_graph_loss: 0.2962 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9086 - accuracy: 0.9206 - scaled_graph_loss: 0.2944 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9047 - accuracy: 0.9304 - scaled_graph_loss: 0.3174 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9214 - accuracy: 0.9202 - scaled_graph_loss: 0.2923 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9081 - accuracy: 0.9276 - scaled_graph_loss: 0.3020 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9043 - accuracy: 0.9220 - scaled_graph_loss: 0.2892 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9022 - accuracy: 0.9253 - scaled_graph_loss: 0.2998 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8871 - accuracy: 
0.9332 - scaled_graph_loss: 0.2979 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8863 - accuracy: 0.9295 - scaled_graph_loss: 0.3021 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8893 - accuracy: 0.9225 - scaled_graph_loss: 0.2928 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8850 - accuracy: 0.9258 - scaled_graph_loss: 0.2997 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9013 - accuracy: 0.9165 - scaled_graph_loss: 0.2961 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8739 - accuracy: 0.9253 - scaled_graph_loss: 0.2886 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8840 - accuracy: 0.9318 - scaled_graph_loss: 0.3040 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8628 - accuracy: 0.9378 - scaled_graph_loss: 0.2886 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8745 - accuracy: 0.9313 - scaled_graph_loss: 0.3013 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8678 - accuracy: 0.9327 - scaled_graph_loss: 0.2980 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8614 - accuracy: 0.9397 - scaled_graph_loss: 0.2947 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8589 - accuracy: 0.9327 - scaled_graph_loss: 0.2957 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8688 - accuracy: 0.9346 - scaled_graph_loss: 0.2996 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8661 - accuracy: 0.9216 - scaled_graph_loss: 0.2881 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8828 - accuracy: 0.9318 - scaled_graph_loss: 0.3019 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8701 - accuracy: 0.9374 - scaled_graph_loss: 0.3051 Epoch 94/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.8572 - accuracy: 0.9383 - scaled_graph_loss: 0.2998 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9327 - scaled_graph_loss: 0.2999 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8685 - accuracy: 0.9336 - scaled_graph_loss: 0.3013 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8710 - accuracy: 0.9378 - scaled_graph_loss: 0.3023 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8746 - accuracy: 0.9327 - scaled_graph_loss: 0.2956 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8642 - accuracy: 0.9341 - scaled_graph_loss: 0.2984 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8638 - accuracy: 0.9318 - scaled_graph_loss: 0.2965 <keras.src.callbacks.History at 0x7f445862f130>
Evaluate the MLP model with graph regularization
# Pair each metric name with the value returned by evaluate().
eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8791 - accuracy: 0.7993
Eval accuracy for MLP + graph regularization : 0.7992766499519348
Eval loss for MLP + graph regularization : 0.8790676593780518
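The evaluation cell relies on Keras returning metric values in the same order as the model's `metrics_names`. Below is a minimal sketch of that pairing. It substitutes the numbers printed above for a real `evaluate` call, and the names `metric_names` and `metric_values` are illustrative, not part of the tutorial.

```python
# Illustrative stand-ins: in the tutorial these come from
# graph_reg_model.metrics_names and graph_reg_model.evaluate(...).
metric_names = ["loss", "accuracy"]
metric_values = [0.8790676593780518, 0.7992766499519348]

# zip() pairs each name with the value at the same position,
# and dict() turns the pairs into a name -> value lookup.
eval_results = dict(zip(metric_names, metric_values))

print("Eval accuracy:", eval_results["accuracy"])
print("Eval loss:", eval_results["loss"])
```

Looking up metrics by name rather than by position keeps the reporting code correct if the order or number of metrics changes.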
The accuracy of the graph-regularized model is about 2-3% higher than that of the base model (base_model).
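We have demonstrated graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial covers synthesizing a graph from sample embeddings before training a neural network with graph regularization. That approach is useful when the input does not contain an explicit graph.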
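To make the idea of synthesizing a graph from sample embeddings concrete, here is a simplified sketch that links two samples whenever the cosine similarity of their embeddings reaches a threshold. It is an illustration under stated assumptions, not the graph builder used by NSL or the advanced tutorial: the function name `build_similarity_edges`, the toy embeddings, and the threshold of 0.8 are all hypothetical.

```python
import numpy as np


def build_similarity_edges(embeddings, threshold=0.8):
    """Return undirected edges (i, j, weight) for pairs of samples whose
    cosine similarity is at least `threshold` (hypothetical helper)."""
    # L2-normalize each row so that a dot product equals cosine similarity.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalized = embeddings / np.maximum(norms, 1e-12)
    similarity = normalized @ normalized.T

    edges = []
    num_samples = embeddings.shape[0]
    for i in range(num_samples):
        # Only visit j > i: the graph is undirected and has no self-loops.
        for j in range(i + 1, num_samples):
            if similarity[i, j] >= threshold:
                edges.append((i, j, float(similarity[i, j])))
    return edges


# Toy embeddings: samples 0 and 1 point in nearly the same direction,
# sample 2 is orthogonal to sample 0.
embeddings = np.array([
    [1.0, 0.0],
    [0.9, 0.1],
    [0.0, 1.0],
])

edges = build_similarity_edges(embeddings, threshold=0.8)
for source, target, weight in edges:
    print(f"{source} -- {target} (similarity {weight:.3f})")
```

With these toy inputs only samples 0 and 1 are linked. The all-pairs loop is quadratic in the number of samples, so a larger dataset would likely need an approximate nearest-neighbor search instead.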