Graph regularization for document classification using natural graphs

Overview

Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.

In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.

The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows (a minimal end-to-end sketch appears after this list):

  1. Generate training data from the input graph and sample features. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
  2. Create a neural network as a base model using the Keras sequential, functional, or subclass API.
  3. Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
  4. Train and evaluate the graph Keras model.
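
Before working through the Cora example, here is a minimal sketch of what these four steps can look like end-to-end. This is not the exact pipeline used below: the file paths and the toy base model are hypothetical placeholders, and step 1 uses nsl.tools.pack_nbrs, one of the utilities NSL provides for merging neighbor features (in this tutorial, the Cora preprocessing script plays that role).

import neural_structured_learning as nsl
import tensorflow as tf

# Step 1 (sketch): augment labeled examples with graph neighbor features.
# '/tmp/labeled.tfr' (tf.train.Examples keyed by an 'id' feature) and
# '/tmp/graph.tsv' (a TSV edge list) are hypothetical placeholders.
nsl.tools.pack_nbrs(
    '/tmp/labeled.tfr',
    '',  # no separate unlabeled examples in this sketch
    '/tmp/graph.tsv',
    '/tmp/merged.tfr',  # output: examples augmented with neighbor features
    add_undirected_edges=True,
    max_nbrs=3)

# Step 2: any Keras model can serve as the base model.
base_model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(1433,), name='words'),
    tf.keras.layers.Dense(50, activation='relu'),
    tf.keras.layers.Dense(7),
])

# Steps 3 and 4: wrap the base model, then compile/fit/evaluate as usual.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=3, multiplier=0.1)
graph_reg_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
# graph_reg_model.fit(train_dataset, epochs=...)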

Setup

Install the Neural Structured Learning package.

pip install --quiet neural-structured-learning

Dependencies and imports

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version:  2.15.0
Eager mode:  True
GPU is NOT AVAILABLE

Cora dataset

The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task involved is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.

The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited paper A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a commutative property.
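
The preprocessing script used later in this tutorial handles this by making all edges bi-directional. As a minimal illustration of the idea with toy paper IDs:

# Toy directed citation edges: paper 'A' cites 'B', and 'B' cites 'C'.
directed_edges = {('A', 'B'), ('B', 'C')}
# Undirected version: add the reverse of every edge.
undirected_edges = directed_edges | {(dst, src) for src, dst in directed_edges}
print(sorted(undirected_edges))
# [('A', 'B'), ('B', 'A'), ('B', 'C'), ('C', 'B')]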

Features

Each paper in the input effectively contains 2 features (see the toy example after this list):

  1. Words: A dense multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words, so the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper.

  2. Label: A single integer representing the class ID (category) of the paper.
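
As a toy illustration of the 'words' feature, here is the multi-hot encoding for a hypothetical 5-word vocabulary (Cora's actual vocabulary has 1433 words):

# Position i is 1 iff vocabulary word i appears in the paper.
vocabulary = ['graph', 'learning', 'neural', 'network', 'regularization']
paper_words = {'neural', 'network', 'graph'}
multi_hot = [1 if word in paper_words else 0 for word in vocabulary]
print(multi_hot)  # [1, 0, 1, 1, 0]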

Download the Cora dataset

wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content

Convert the Cora data to the NSL format

In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL github repository. This script does the following:

  1. Generate neighbor features using the original node features and the graph.
  2. Generate train and test data splits containing tf.train.Example instances.
  3. Persist the resulting train and test data in the TFRecord format. (A quick sanity check on this output is sketched after the script's log below.)

!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2023-11-16 12:04:52--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]  11.37K  --.-KB/s    in 0s      

2023-11-16 12:04:53 (75.6 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.44 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.05 minutes.
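
As the optional sanity check mentioned above, we can inspect the feature keys of one merged example to see the neighbor features the script added (the 'NL_nbr_' naming convention is described in the next section):

raw_record = next(
    iter(tf.data.TFRecordDataset('/tmp/cora/train_merged_examples.tfr')))
merged_example = tf.train.Example.FromString(raw_record.numpy())
# Expect keys like 'NL_nbr_0_words' and 'NL_nbr_0_weight' alongside 'words'
# and 'label'.
print(sorted(merged_example.features.feature.keys()))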

Global variables

The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

Hyperparameters

We will use an instance of HParams to hold the various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:

  • num_classes: There are a total of 7 different classes.

  • max_seq_length: This is the size of the vocabulary; every instance in the input has a dense multi-hot, bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input and a value of 0 indicates that it is not.

  • distance_type: This is the distance metric used to regularize the sample with its neighbors.

  • graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function (see the objective sketched after this list).

  • num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.

  • num_fc_units: The dimensions (number of units) of the fully connected layers in the neural network; the length of this list equals the number of hidden layers.

  • train_epochs: The number of training epochs.

  • batch_size: Batch size used for training and evaluation.

  • dropout_rate: Controls the rate of dropout following each fully connected layer.

  • eval_steps: The number of batches to process before deeming evaluation complete. If set to None, all instances in the test set are evaluated.
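
To make the roles of distance_type and graph_regularization_multiplier concrete: glossing over implementation details such as batching and how neighbor terms are averaged, the graph-regularized training objective has roughly the form

$$\mathcal{L} = \mathcal{L}_{\text{supervised}} + \alpha \sum_{(i,j) \in \mathcal{E}} w_{ij}\, d\big(h(x_i), h(x_j)\big)$$

where $\alpha$ is graph_regularization_multiplier, $\mathcal{E}$ is the set of sample-neighbor pairs (up to num_neighbors per sample), $w_{ij}$ is the neighbor weight, $d$ is the distance_type metric (L2 here), and $h(x)$ denotes the model's output for sample $x$. Non-existent neighbors receive weight 0 and therefore do not contribute to the sum.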

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()

Load train and test data

As described earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects, one for training and one for testing.

In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample, but also the corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for the non-existent neighbor features.

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)

Let's peek into the train dataset to look at its contents.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 1 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[2 2 3 6 6 4 3 1 3 4 2 5 4 5 6 4 1 5 1 0 5 6 3 0 4 2 4 4 1 1 1 6 2 2 5 3 3
 5 3 2 0 0 1 5 5 0 4 6 1 4 2 0 2 4 4 1 3 2 2 2 1 2 2 5 2 2 4 1 2 6 1 6 3 0
 5 2 6 4 3 2 4 0 2 1 2 2 2 2 2 2 1 1 6 3 2 4 1 2 1 0 3 0 0 3 2 6 1 2 2 1 2
 2 2 3 2 0 2 3 2 5 3 0 1 1 2 0 2 1], shape=(128,), dtype=int64)

Let's peek into the test dataset to look at its contents.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

Model definition

To demonstrate the use of graph regularization, we first build a base model for this problem: a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework: sequential, functional, and subclass.

Sequential base model

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes))
  return model

Functional base model

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

Subclass base model

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(hparams.num_classes)

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()

Create base model(s)

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 words (InputLayer)          [(None, 1433)]            0         
                                                                 
 lambda (Lambda)             (None, 1433)              0         
                                                                 
 dense (Dense)               (None, 50)                71700     
                                                                 
 dropout (Dropout)           (None, 50)                0         
                                                                 
 dense_1 (Dense)             (None, 50)                2550      
                                                                 
 dropout_1 (Dropout)         (None, 50)                0         
                                                                 
 dense_2 (Dense)             (None, 7)                 357       
                                                                 
=================================================================
Total params: 74607 (291.43 KB)
Trainable params: 74607 (291.43 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

Train base MLP model

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/functional.py:642: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model.
  inputs = self._flatten_to_reference_inputs(inputs)
17/17 [==============================] - 1s 6ms/step - loss: 1.9105 - accuracy: 0.2260
Epoch 2/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8280 - accuracy: 0.3044
Epoch 3/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7240 - accuracy: 0.3299
Epoch 4/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5969 - accuracy: 0.3745
Epoch 5/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4765 - accuracy: 0.4492
Epoch 6/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3235 - accuracy: 0.5276
Epoch 7/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1913 - accuracy: 0.5889
Epoch 8/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0604 - accuracy: 0.6432
Epoch 9/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9628 - accuracy: 0.6821
Epoch 10/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8601 - accuracy: 0.7234
Epoch 11/100
17/17 [==============================] - 0s 3ms/step - loss: 0.7914 - accuracy: 0.7480
Epoch 12/100
17/17 [==============================] - 0s 3ms/step - loss: 0.7230 - accuracy: 0.7633
Epoch 13/100
17/17 [==============================] - 0s 3ms/step - loss: 0.6783 - accuracy: 0.7791
Epoch 14/100
17/17 [==============================] - 0s 3ms/step - loss: 0.6019 - accuracy: 0.8070
Epoch 15/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5587 - accuracy: 0.8367
Epoch 16/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5295 - accuracy: 0.8450
Epoch 17/100
17/17 [==============================] - 0s 3ms/step - loss: 0.4789 - accuracy: 0.8599
Epoch 18/100
17/17 [==============================] - 0s 3ms/step - loss: 0.4474 - accuracy: 0.8650
Epoch 19/100
17/17 [==============================] - 0s 3ms/step - loss: 0.4148 - accuracy: 0.8701
Epoch 20/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3812 - accuracy: 0.8896
Epoch 21/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3656 - accuracy: 0.8863
Epoch 22/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3544 - accuracy: 0.8923
Epoch 23/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3050 - accuracy: 0.9165
Epoch 24/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2858 - accuracy: 0.9216
Epoch 25/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2821 - accuracy: 0.9234
Epoch 26/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2543 - accuracy: 0.9276
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2477 - accuracy: 0.9285
Epoch 28/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2413 - accuracy: 0.9295
Epoch 29/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2153 - accuracy: 0.9415
Epoch 30/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2241 - accuracy: 0.9290
Epoch 31/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2118 - accuracy: 0.9374
Epoch 32/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2041 - accuracy: 0.9471
Epoch 33/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1951 - accuracy: 0.9392
Epoch 34/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1841 - accuracy: 0.9443
Epoch 35/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1783 - accuracy: 0.9522
Epoch 36/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1742 - accuracy: 0.9485
Epoch 37/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1705 - accuracy: 0.9541
Epoch 38/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1507 - accuracy: 0.9592
Epoch 39/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1513 - accuracy: 0.9555
Epoch 40/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1378 - accuracy: 0.9652
Epoch 41/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1471 - accuracy: 0.9587
Epoch 42/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1309 - accuracy: 0.9661
Epoch 43/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1288 - accuracy: 0.9596
Epoch 44/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1327 - accuracy: 0.9629
Epoch 45/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1170 - accuracy: 0.9675
Epoch 46/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1198 - accuracy: 0.9666
Epoch 47/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1183 - accuracy: 0.9680
Epoch 48/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1025 - accuracy: 0.9740
Epoch 49/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0981 - accuracy: 0.9754
Epoch 50/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9708
Epoch 51/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0874 - accuracy: 0.9796
Epoch 52/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1027 - accuracy: 0.9735
Epoch 53/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0993 - accuracy: 0.9740
Epoch 54/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0934 - accuracy: 0.9759
Epoch 55/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0932 - accuracy: 0.9759
Epoch 56/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0787 - accuracy: 0.9810
Epoch 57/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0890 - accuracy: 0.9754
Epoch 58/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0918 - accuracy: 0.9749
Epoch 59/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0908 - accuracy: 0.9717
Epoch 60/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0825 - accuracy: 0.9777
Epoch 61/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0926 - accuracy: 0.9684
Epoch 62/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0702 - accuracy: 0.9800
Epoch 63/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0720 - accuracy: 0.9842
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0792 - accuracy: 0.9773
Epoch 65/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0760 - accuracy: 0.9782
Epoch 66/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0736 - accuracy: 0.9800
Epoch 67/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0838 - accuracy: 0.9773
Epoch 68/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0639 - accuracy: 0.9824
Epoch 69/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0742 - accuracy: 0.9805
Epoch 70/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0798 - accuracy: 0.9782
Epoch 71/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0694 - accuracy: 0.9805
Epoch 72/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0635 - accuracy: 0.9833
Epoch 73/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0587 - accuracy: 0.9824
Epoch 74/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0689 - accuracy: 0.9828
Epoch 75/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0628 - accuracy: 0.9828
Epoch 76/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0570 - accuracy: 0.9842
Epoch 77/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0632 - accuracy: 0.9824
Epoch 78/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0673 - accuracy: 0.9782
Epoch 79/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0573 - accuracy: 0.9828
Epoch 80/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0640 - accuracy: 0.9824
Epoch 81/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0610 - accuracy: 0.9810
Epoch 82/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0553 - accuracy: 0.9861
Epoch 83/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0482 - accuracy: 0.9879
Epoch 84/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0548 - accuracy: 0.9842
Epoch 85/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0537 - accuracy: 0.9865
Epoch 86/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0540 - accuracy: 0.9828
Epoch 87/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0528 - accuracy: 0.9838
Epoch 88/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0505 - accuracy: 0.9865
Epoch 89/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0473 - accuracy: 0.9833
Epoch 90/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0604 - accuracy: 0.9810
Epoch 91/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0469 - accuracy: 0.9879
Epoch 92/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0554 - accuracy: 0.9810
Epoch 93/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0427 - accuracy: 0.9875
Epoch 94/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0581 - accuracy: 0.9824
Epoch 95/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0488 - accuracy: 0.9842
Epoch 96/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0466 - accuracy: 0.9875
Epoch 97/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0465 - accuracy: 0.9875
Epoch 98/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0411 - accuracy: 0.9879
Epoch 99/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0539 - accuracy: 0.9852
Epoch 100/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0451 - accuracy: 0.9870
<keras.src.callbacks.History at 0x7f459c2e9e50>

Evaluate base MLP model

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.4164 - accuracy: 0.7758


Eval accuracy for  Base MLP model :  0.775768518447876
Eval loss for  Base MLP model :  1.4164185523986816

Train MLP model with graph regularization

Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.

To assess the incremental benefit of graph regularization, we will create a new instance of the base model. This is because base_model has already been trained for a few iterations, and reusing that trained model to create a graph-regularized model would not be a fair comparison for base_model.

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
17/17 [==============================] - 2s 7ms/step - loss: 1.9586 - accuracy: 0.2107 - scaled_graph_loss: 0.0319
Epoch 2/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8903 - accuracy: 0.2942 - scaled_graph_loss: 0.0282
Epoch 3/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8290 - accuracy: 0.3262 - scaled_graph_loss: 0.0411
Epoch 4/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7762 - accuracy: 0.3248 - scaled_graph_loss: 0.0604
Epoch 5/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7334 - accuracy: 0.3568 - scaled_graph_loss: 0.0792
Epoch 6/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6859 - accuracy: 0.3735 - scaled_graph_loss: 0.0920
Epoch 7/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6506 - accuracy: 0.3935 - scaled_graph_loss: 0.1086
Epoch 8/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6028 - accuracy: 0.4520 - scaled_graph_loss: 0.1249
Epoch 9/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5690 - accuracy: 0.5012 - scaled_graph_loss: 0.1386
Epoch 10/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5332 - accuracy: 0.5420 - scaled_graph_loss: 0.1577
Epoch 11/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4792 - accuracy: 0.5842 - scaled_graph_loss: 0.1642
Epoch 12/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4438 - accuracy: 0.6306 - scaled_graph_loss: 0.1909
Epoch 13/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4155 - accuracy: 0.6617 - scaled_graph_loss: 0.2009
Epoch 14/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3596 - accuracy: 0.6896 - scaled_graph_loss: 0.1964
Epoch 15/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3462 - accuracy: 0.7077 - scaled_graph_loss: 0.2294
Epoch 16/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3151 - accuracy: 0.7295 - scaled_graph_loss: 0.2312
Epoch 17/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2848 - accuracy: 0.7555 - scaled_graph_loss: 0.2319
Epoch 18/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2643 - accuracy: 0.7759 - scaled_graph_loss: 0.2469
Epoch 19/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2434 - accuracy: 0.7921 - scaled_graph_loss: 0.2544
Epoch 20/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2005 - accuracy: 0.8093 - scaled_graph_loss: 0.2473
Epoch 21/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2007 - accuracy: 0.8070 - scaled_graph_loss: 0.2688
Epoch 22/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1876 - accuracy: 0.8135 - scaled_graph_loss: 0.2708
Epoch 23/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1729 - accuracy: 0.8274 - scaled_graph_loss: 0.2662
Epoch 24/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8376 - scaled_graph_loss: 0.2707
Epoch 25/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1228 - accuracy: 0.8538 - scaled_graph_loss: 0.2677
Epoch 26/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1166 - accuracy: 0.8603 - scaled_graph_loss: 0.2785
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1176 - accuracy: 0.8473 - scaled_graph_loss: 0.2807
Epoch 28/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1085 - accuracy: 0.8473 - scaled_graph_loss: 0.2649
Epoch 29/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0751 - accuracy: 0.8691 - scaled_graph_loss: 0.2858
Epoch 30/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0851 - accuracy: 0.8696 - scaled_graph_loss: 0.2996
Epoch 31/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0932 - accuracy: 0.8770 - scaled_graph_loss: 0.2892
Epoch 32/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0619 - accuracy: 0.8821 - scaled_graph_loss: 0.2880
Epoch 33/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0531 - accuracy: 0.8886 - scaled_graph_loss: 0.2847
Epoch 34/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0558 - accuracy: 0.8863 - scaled_graph_loss: 0.2962
Epoch 35/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0375 - accuracy: 0.8891 - scaled_graph_loss: 0.2780
Epoch 36/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0310 - accuracy: 0.8858 - scaled_graph_loss: 0.2932
Epoch 37/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0269 - accuracy: 0.8872 - scaled_graph_loss: 0.2916
Epoch 38/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0273 - accuracy: 0.8928 - scaled_graph_loss: 0.2948
Epoch 39/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9935 - accuracy: 0.9123 - scaled_graph_loss: 0.2910
Epoch 40/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0083 - accuracy: 0.9104 - scaled_graph_loss: 0.2951
Epoch 41/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0196 - accuracy: 0.8951 - scaled_graph_loss: 0.2982
Epoch 42/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9941 - accuracy: 0.9007 - scaled_graph_loss: 0.2898
Epoch 43/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0069 - accuracy: 0.9012 - scaled_graph_loss: 0.3076
Epoch 44/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9816 - accuracy: 0.9049 - scaled_graph_loss: 0.2930
Epoch 45/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9910 - accuracy: 0.9104 - scaled_graph_loss: 0.2954
Epoch 46/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9949 - accuracy: 0.9026 - scaled_graph_loss: 0.3111
Epoch 47/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9715 - accuracy: 0.9114 - scaled_graph_loss: 0.2830
Epoch 48/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9796 - accuracy: 0.9067 - scaled_graph_loss: 0.2970
Epoch 49/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9570 - accuracy: 0.9114 - scaled_graph_loss: 0.2936
Epoch 50/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9691 - accuracy: 0.9049 - scaled_graph_loss: 0.2940
Epoch 51/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9803 - accuracy: 0.9114 - scaled_graph_loss: 0.3083
Epoch 52/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9612 - accuracy: 0.9128 - scaled_graph_loss: 0.2860
Epoch 53/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9627 - accuracy: 0.9216 - scaled_graph_loss: 0.3077
Epoch 54/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9516 - accuracy: 0.9151 - scaled_graph_loss: 0.2906
Epoch 55/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9431 - accuracy: 0.9197 - scaled_graph_loss: 0.2967
Epoch 56/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9622 - accuracy: 0.9132 - scaled_graph_loss: 0.3053
Epoch 57/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9410 - accuracy: 0.9188 - scaled_graph_loss: 0.2830
Epoch 58/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9531 - accuracy: 0.9230 - scaled_graph_loss: 0.3049
Epoch 59/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9309 - accuracy: 0.9193 - scaled_graph_loss: 0.3009
Epoch 60/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9300 - accuracy: 0.9248 - scaled_graph_loss: 0.2988
Epoch 61/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9173 - accuracy: 0.9244 - scaled_graph_loss: 0.2884
Epoch 62/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9228 - accuracy: 0.9248 - scaled_graph_loss: 0.2960
Epoch 63/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9394 - accuracy: 0.9174 - scaled_graph_loss: 0.3102
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9182 - accuracy: 0.9174 - scaled_graph_loss: 0.2899
Epoch 65/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9276 - accuracy: 0.9253 - scaled_graph_loss: 0.2996
Epoch 66/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9229 - accuracy: 0.9244 - scaled_graph_loss: 0.2912
Epoch 67/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9325 - accuracy: 0.9142 - scaled_graph_loss: 0.3088
Epoch 68/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9091 - accuracy: 0.9216 - scaled_graph_loss: 0.2883
Epoch 69/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8987 - accuracy: 0.9267 - scaled_graph_loss: 0.2924
Epoch 70/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9188 - accuracy: 0.9216 - scaled_graph_loss: 0.2970
Epoch 71/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9003 - accuracy: 0.9299 - scaled_graph_loss: 0.2962
Epoch 72/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9086 - accuracy: 0.9206 - scaled_graph_loss: 0.2944
Epoch 73/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9047 - accuracy: 0.9304 - scaled_graph_loss: 0.3174
Epoch 74/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9214 - accuracy: 0.9202 - scaled_graph_loss: 0.2923
Epoch 75/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9081 - accuracy: 0.9276 - scaled_graph_loss: 0.3020
Epoch 76/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9043 - accuracy: 0.9220 - scaled_graph_loss: 0.2892
Epoch 77/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9022 - accuracy: 0.9253 - scaled_graph_loss: 0.2998
Epoch 78/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8871 - accuracy: 0.9332 - scaled_graph_loss: 0.2979
Epoch 79/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8863 - accuracy: 0.9295 - scaled_graph_loss: 0.3021
Epoch 80/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8893 - accuracy: 0.9225 - scaled_graph_loss: 0.2928
Epoch 81/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8850 - accuracy: 0.9258 - scaled_graph_loss: 0.2997
Epoch 82/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9013 - accuracy: 0.9165 - scaled_graph_loss: 0.2961
Epoch 83/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8739 - accuracy: 0.9253 - scaled_graph_loss: 0.2886
Epoch 84/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8840 - accuracy: 0.9318 - scaled_graph_loss: 0.3040
Epoch 85/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8628 - accuracy: 0.9378 - scaled_graph_loss: 0.2886
Epoch 86/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8745 - accuracy: 0.9313 - scaled_graph_loss: 0.3013
Epoch 87/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8678 - accuracy: 0.9327 - scaled_graph_loss: 0.2980
Epoch 88/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8614 - accuracy: 0.9397 - scaled_graph_loss: 0.2947
Epoch 89/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8589 - accuracy: 0.9327 - scaled_graph_loss: 0.2957
Epoch 90/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8688 - accuracy: 0.9346 - scaled_graph_loss: 0.2996
Epoch 91/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8661 - accuracy: 0.9216 - scaled_graph_loss: 0.2881
Epoch 92/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8828 - accuracy: 0.9318 - scaled_graph_loss: 0.3019
Epoch 93/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8701 - accuracy: 0.9374 - scaled_graph_loss: 0.3051
Epoch 94/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8572 - accuracy: 0.9383 - scaled_graph_loss: 0.2998
Epoch 95/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9327 - scaled_graph_loss: 0.2999
Epoch 96/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8685 - accuracy: 0.9336 - scaled_graph_loss: 0.3013
Epoch 97/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8710 - accuracy: 0.9378 - scaled_graph_loss: 0.3023
Epoch 98/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8746 - accuracy: 0.9327 - scaled_graph_loss: 0.2956
Epoch 99/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8642 - accuracy: 0.9341 - scaled_graph_loss: 0.2984
Epoch 100/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8638 - accuracy: 0.9318 - scaled_graph_loss: 0.2965
<keras.src.callbacks.History at 0x7f445862f130>

Evaluate MLP model with graph regularization

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8791 - accuracy: 0.7993


Eval accuracy for  MLP + graph regularization :  0.7992766499519348
Eval loss for  MLP + graph regularization :  0.8790676593780518

The accuracy of the graph-regularized model is about 2-3% higher than that of the base model (base_model): 0.776 vs. 0.799 in this run, an absolute gain of roughly 2.4 percentage points.

Conclusion

We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs based on sample embeddings before training a neural network with graph regularization. This approach is useful if the input does not contain an explicit graph.
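
Graph building from embeddings is itself available as an NSL utility. A minimal sketch, assuming '/tmp/embeddings.tfr' holds tf.train.Example instances with 'id' and 'embedding' features (the path and threshold are placeholders):

import neural_structured_learning as nsl

# Connect every pair of samples whose embedding similarity is at least the
# given threshold; the resulting TSV edge list can then be fed into the same
# preprocessing flow used in this tutorial.
nsl.tools.build_graph(['/tmp/embeddings.tfr'], '/tmp/synthesized_graph.tsv',
                      similarity_threshold=0.8)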

We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.