创建自定义反事实逻辑配对数据集

将反事实逻辑配对 (CLP) 应用于评估和改进模型的公平性需要一个反事实数据集。您可以通过复制现有数据集并更改新数据集以添加、删除或修改身份术语来创建反事实数据集。本教程解释了为现有文本数据集创建反事实数据集的方法和技术。

您可以通过创建一个新的数据对象 CounterfactualPackedInputs 来使用 CLP 技术与反事实数据集,该对象包含 original_inputcounterfactual_data,如下所示

CounterfactualPackedInputs 如下所示

CounterfactualPackedInputs(
  original_input=(x, y, sample_weight),
  counterfactual_data=(original_x, counterfactual_x,
                       counterfactual_sample_weight)
)

original_input 应该是用于训练 Keras 模型的原始数据集。 counterfactual_data 应该是一个 tf.data.Dataset,其中包含原始 x 值、相应的 counterfactual_x 值和 counterfactual_sample_weightcounterfactual_x 值与原始值几乎相同,但删除或替换了一个或多个属性。此数据集用于将原始值和反事实值之间的损失函数配对,目的是确保当敏感属性不同时,模型的预测不会改变。 original_inputcounterfactual_data 需要具有相同的形状。您可以从 counterfactual_data 中复制值,使其与 original_input 的元素数量相同。

counterfactual_data 的属性

  • 所有 original_x 值都需要引用身份组
  • 每个 counterfactual_x 值都与原始值相同,但删除或替换了一个或多个属性
  • 与原始输入具有相同的形状(您可以复制值,使其具有相同的形状)

counterfactual_data 不需要

  • 与原始输入中的数据重叠
  • 具有真实标签

以下是一个 counterfactual_data 的示例,如果您删除了“gay”一词。

original_x: “I am a gay man”
counterfactual_x: “I am a man” 
counterfactual_sample_weight”: 1

如果您有一个文本分类器,您可以使用 build_counterfactual_data 来帮助创建反事实数据集。对于所有其他数据类型,您需要直接提供反事实数据集。

设置

您将从安装 TensorFlow 模型修复开始。

pip install --upgrade tensorflow-model-remediation
import tensorflow as tf
from tensorflow_model_remediation import counterfactual

创建一个简单的 Dataset

为了演示目的,我们将使用 build_counterfactual_dataset 从原始输入创建反事实数据。请注意,您也可以从未标记的数据构建反事实数据(而不是从原始输入构建)。您将创建一个包含一个句子的简单数据集:“i am a gay man”,它将用作 original_input

构建反事实数据集

由于这是一个文本分类器,您可以使用 build_counterfactual_data 通过两种方式创建反事实数据集

  1. 删除术语:使用 build_counterfactual_data 传递一个词语列表,这些词语将通过 tf.strings.regex_replace 从数据集中删除。
  2. 替换术语:创建一个自定义函数以传递给 build_counterfactual_data。这可能包括使用更具体的正则表达式函数来替换原始数据集中词语,或支持非文本特征

build_counterfactual_dataset 接收 original_input,并根据您传递的可选参数删除或替换术语。在大多数情况下,删除术语(选项 1)应该足以运行 CLP,但是传递自定义函数(选项 2)可以更精确地控制反事实值。

选项 1:要删除的词语列表

使用 build_counterfactual_data 传递一个包含要删除的性别相关术语的列表。

当使用简单的正则表达式创建反事实数据集时,请记住这可能会增强不应该更改的词语。最好检查对 counterfactual_x 值所做的更改在 orginal_x 值的上下文中是否有意义。此外,build_counterfactual_dataset 将仅返回包含反事实实例的值。这可能会导致与 orginal_input 形状不同的数据集,但在传递给 pack_counterfactual_data 时将调整大小。

simple_dataset_x = tf.constant(
    ["I am a gay man" + str(i) for i in range(10)] +
    ["I am a man" + str(i) for i in range(10)])
print("Length of starting values: " + str(len(simple_dataset_x)))

simple_dataset = tf.data.Dataset.from_tensor_slices(
            (simple_dataset_x, None, None))

counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset,
    sensitive_terms_to_remove=['gay'])

# Inspect the content of the TF Counterfactual Dataset
for original_value, counterfactual_value, _ in counterfactual_data.take(1):
  print("original: ", original_value)
  print("counterfactual: ", counterfactual_value)
print("Length of dataset after build_counterfactual_data: " +
      str(len(list(counterfactual_data))))

选项 2:自定义函数

为了更灵活地修改原始数据集,您可以将自定义函数传递给 build_counterfactual_data

在示例中,您可以考虑将引用男性的身份术语替换为引用女性的术语。这可以通过编写一个函数来替换词语字典来完成。

请注意,自定义函数的唯一限制是它必须是可调用的,以接受和返回在 Model.fit 中使用的格式的元组,并且应该删除不包含任何更改的值,这可以通过将术语传递给 sensitive_terms_to_remove 来完成。

words_to_replace = {"man": "woman"}
print("Length of starting values: " + str(len(simple_dataset_x)))

def replace_words(original_batch):
  original_x, _, original_sample_weight = (
      tf.keras.utils.unpack_x_y_sample_weight(original_batch))
  for word in words_to_replace:
    counterfactual_x = tf.strings.regex_replace(
        original_x, f'{word}', words_to_replace[word])
  return tf.keras.utils.pack_x_y_sample_weight(
      original_x, counterfactual_x, sample_weight=original_sample_weight)

counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset,
    sensitive_terms_to_remove=['gay'],
    custom_counterfactual_function=replace_words)

# Inspect the content of the TF Counterfactual Dataset
for original_value, counterfactual_value in counterfactual_data.take(1):
  print("original: ", original_value)
  print("counterfactual: ", counterfactual_value)
print("Length of dataset after build_counterfactual_data: " +
      str(len(list(counterfactual_data))))

要了解更多信息,请参阅 build_counterfactual_data 的 API 文档。