Learning to rank with decision forests


Welcome to the **learning to rank** colab for **TensorFlow Decision Forests** (**TF-DF**). In this colab, you will learn how to use **TF-DF** for ranking.

This colab assumes you are familiar with the concepts presented in the beginner colab, notably about the installation of TF-DF.

In this colab, you will:

  1. Learn what a ranking model is.
  2. Train a Gradient Boosted Trees model on the LETOR3 dataset.
  3. Evaluate the quality of this model.

Installing TensorFlow Decision Forests

Install TF-DF by running the following cell.

pip install tensorflow_decision_forests

Wurlitzer is needed to display the detailed training logs in Colabs (when using verbose=2 in the model constructor).

pip install wurlitzer

Importing libraries

import os
# Keep using Keras 2
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import tensorflow_decision_forests as tfdf

import numpy as np
import pandas as pd
import tensorflow as tf
import tf_keras
import math

A hidden code cell limits the output height in colab.

# Check the version of TensorFlow Decision Forests
print("Found TensorFlow Decision Forests v" + tfdf.__version__)
Found TensorFlow Decision Forests v1.9.0

What is a ranking model?

The goal of a ranking model is to **correctly order items**. For example, ranking can be used to select the best **documents** to retrieve following a user **query**.

A common way to represent a ranking dataset is with a "relevance" score: the order of the elements is defined by their relevance, so items with higher relevance should come before items with lower relevance. The cost of a mistake is the difference between the relevance of the predicted item and the relevance of the correct item. For example, misordering two items with relevance 3 and 4 is not as bad as misordering two items with relevance 1 and 5.

TF-DF expects ranking datasets to be presented in a "flat" format. A dataset of queries and corresponding documents might look like this:

| query | document_id | feature_1 | feature_2 | relevance |
|-------|-------------|-----------|-----------|-----------|
| cat   | 1           | 0.1       | blue      | 4         |
| cat   | 2           | 0.5       | green     | 1         |
| cat   | 3           | 0.2       | red       | 2         |
| dog   | 4           | NA        | red       | 0         |
| dog   | 5           | 0.2       | red       | 0         |
| dog   | 6           | 0.6       | green     | 1         |

The relevance/label is a floating point numerical value between 0 and 5 (generally between 0 and 4), where 0 means "completely unrelated", 4 means "very relevant", and 5 means "the same as the query".

In this example, document 1 is very relevant to the query "cat", while document 2 is only "related" to cats. No document really talks about "dog" (the highest relevance is 1, for document 6). However, the dog query would still be expected to return document 6 (since it is the document that talks "most" about dogs).
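To make the "flat" format concrete, here is a minimal sketch that builds the toy dataset above as a Pandas dataframe (the column names are illustrative; TF-DF only needs the group, the features, and the relevance label):

import pandas as pd

# The toy ranking dataset above, in "flat" format: one row per query/document pair.
toy_df = pd.DataFrame({
    "query": ["cat", "cat", "cat", "dog", "dog", "dog"],
    "document_id": [1, 2, 3, 4, 5, 6],
    "feature_1": [0.1, 0.5, 0.2, None, 0.2, 0.6],
    "feature_2": ["blue", "green", "red", "red", "red", "green"],
    "relevance": [4, 1, 2, 0, 0, 1],
})

# Such a dataframe could then be converted with, e.g.:
# tfdf.keras.pd_dataframe_to_tf_dataset(toy_df, label="relevance", task=tfdf.keras.Task.RANKING)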

Interestingly, decision forests are often good rankers, and many state-of-the-art ranking models are decision forests.

Let's train a ranking model

In this example, we use a sample of the LETOR3 dataset. More precisely, we want to download OHSUMED.zip from the LETOR3 repository. This dataset is stored in the libsvm format, so we will need to convert it to csv.

archive_path = tf_keras.utils.get_file("letor.zip",
  "https://download.microsoft.com/download/E/7/E/E7EABEF1-4C7B-4E31-ACE5-73927950ED5E/Letor.zip",
  extract=True)

# Path to a ranking dataset using libsvm format.
raw_dataset_path = os.path.join(os.path.dirname(archive_path),"OHSUMED/Data/Fold1/trainingset.txt")
Downloading data from https://download.microsoft.com/download/E/7/E/E7EABEF1-4C7B-4E31-ACE5-73927950ED5E/Letor.zip
61824018/61824018 [==============================] - 7s 0us/step

Here are the first lines of the dataset:

head {raw_dataset_path}


The first step is to convert this dataset into the "flat" format mentioned above.
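To make the conversion concrete, here is roughly what one LETOR3-style libsvm line looks like, and how the converter below tokenizes it (the values, the document id, and the exact comment shape are illustrative):

sample_line = "2 qid:1 1:0.056 2:0.24 25:0.31 #docid = 244338"  # illustrative line
items = sample_line.split(" ")[:-3]             # the last 3 tokens are a comment
print(items[0])                                 # relevance: "2"
print(items[1].split(":")[1])                   # group/query id: "1"
print([item.split(":") for item in items[2:]])  # [feature_id, value] pairs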

def convert_libsvm_to_csv(src_path, dst_path):
  """Converts a libsvm ranking dataset into a flat csv file.

  Note: This code is specific to the LETOR3 dataset.
  """
  dst_handle = open(dst_path, "w")
  first_line = True
  for src_line in open(src_path,"r"):
    # Note: The last 3 items are comments.
    items = src_line.split(" ")[:-3]
    relevance = items[0]
    group = items[1].split(":")[1]
    features = [ item.split(":") for item in items[2:]]

    if first_line:
      # Csv header
      dst_handle.write("relevance,group," + ",".join(["f_" + feature[0] for feature in features]) + "\n")
      first_line = False
    dst_handle.write(relevance + ",g_" + group + "," + (",".join([feature[1] for feature in features])) + "\n")
  dst_handle.close()

# Convert the dataset.
csv_dataset_path="/tmp/ohsumed.csv"
convert_libsvm_to_csv(raw_dataset_path, csv_dataset_path)

# Load a dataset into a Pandas Dataframe.
dataset_df = pd.read_csv(csv_dataset_path)

# Display the first 3 examples.
dataset_df.head(3)

In this dataset, each row represents a query/document pair (the query is called a "group"). The "relevance" tells how well the query matches the document.

The features of the query and the document are merged together in the features "f_1" to "f_25". The exact definition of the features is not known, but they could be things like:

  • the number of words in the query
  • the number of words in common between the query and the document
  • the cosine similarity between a query embedding and a document embedding
  • ...
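A quick sanity check can confirm the columns and relevance grades of the flat file (a sketch; the exact output depends on the converted dataset):

# Inspect the flat dataset (output depends on the converted file).
print(dataset_df.columns.tolist())               # expected: ['relevance', 'group', 'f_1', ..., 'f_25']
print(dataset_df["group"].nunique())             # number of distinct queries/groups
print(sorted(dataset_df["relevance"].unique()))  # relevance grades present in OHSUMED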

Let's convert the Pandas dataframe into a TensorFlow dataset:

dataset_ds = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="relevance", task=tfdf.keras.Task.RANKING)

Let's configure and train our ranking model.

%set_cell_height 400

model = tfdf.keras.GradientBoostedTreesModel(
    task=tfdf.keras.Task.RANKING,
    ranking_group="group",
    num_trees=50)

model.fit(dataset_ds)
<IPython.core.display.Javascript object>
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpzqjjgty3 as temporary training directory
Reading training dataset...
[WARNING 24-04-20 11:09:14.5069 UTC gradient_boosted_trees.cc:1840] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 24-04-20 11:09:14.5069 UTC gradient_boosted_trees.cc:1851] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 24-04-20 11:09:14.5069 UTC gradient_boosted_trees.cc:1865] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
Training dataset read in 0:00:03.986733. Found 9219 examples.
Training model...
Model trained in 0:00:00.757738
Compiling model...
[INFO 24-04-20 11:09:19.2736 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmpzqjjgty3/model/ with prefix fa7585ffd7c24e56
[INFO 24-04-20 11:09:19.2748 UTC quick_scorer_extended.cc:911] The binary was compiled without AVX2 support, but your CPU supports it. Enable it for faster model inference.
[INFO 24-04-20 11:09:19.2749 UTC abstract_model.cc:1344] Engine "GradientBoostedTreesQuickScorerExtended" built
[INFO 24-04-20 11:09:19.2749 UTC kernel.cc:1061] Use fast generic engine
Model compiled.
<tf_keras.src.callbacks.History at 0x7fb6979cc8b0>

We can now look at the quality of the model on the validation dataset. By default, TF-DF trains ranking models to optimize the NDCG. The NDCG is a value between 0 and 1, where 1 is the perfect score. For this reason, -NDCG is used as the model loss.
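For intuition, here is a minimal sketch of one common NDCG formulation (exponential gains with logarithmic position discounts); TF-DF's internal implementation may differ in details such as tie handling:

def dcg_at_k(relevances, k=5):
  # Discounted cumulative gain: gain (2^rel - 1) discounted by log2(rank + 2).
  return sum((2.0**rel - 1.0) / math.log2(rank + 2)
             for rank, rel in enumerate(relevances[:k]))

def ndcg_at_k(relevances_in_predicted_order, k=5):
  # NDCG@k: DCG of the predicted order, normalized by the DCG of the ideal order.
  ideal_dcg = dcg_at_k(sorted(relevances_in_predicted_order, reverse=True), k)
  return dcg_at_k(relevances_in_predicted_order, k) / ideal_dcg if ideal_dcg > 0 else 0.0

# True relevances of 6 documents, listed in the order the model ranked them.
print(ndcg_at_k([4, 1, 2, 0, 0, 1]))  # < 1.0; a perfect ordering would score 1.0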

import matplotlib.pyplot as plt

logs = model.make_inspector().training_logs()

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot([log.num_trees for log in logs], [log.evaluation.ndcg for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("NDCG (validation)")

plt.subplot(1, 2, 2)
plt.plot([log.num_trees for log in logs], [log.evaluation.loss for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("Loss (validation)")

plt.show()

[Figure: NDCG (validation) and loss (validation) as a function of the number of trees]

As with all TF-DF models, you can also look at the model report (note: the model report also contains the training logs):

%set_cell_height 400
model.summary()
<IPython.core.display.Javascript object>
Model: "gradient_boosted_trees_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
=================================================================
Total params: 1 (1.00 Byte)
Trainable params: 0 (0.00 Byte)
Non-trainable params: 1 (1.00 Byte)
_________________________________________________________________
Type: "GRADIENT_BOOSTED_TREES"
Task: RANKING
Label: "__LABEL"
Rank group: "group"

Input Features (25):
    f_1
    f_10
    f_11
    f_12
    f_13
    f_14
    f_15
    f_16
    f_17
    f_18
    f_19
    f_2
    f_20
    f_21
    f_22
    f_23
    f_24
    f_25
    f_3
    f_4
    f_5
    f_6
    f_7
    f_8
    f_9

No weights

Variable Importance: INV_MEAN_MIN_DEPTH:

    1.  "f_9"  0.326164 ################
    2.  "f_3"  0.318071 ###############
    3.  "f_8"  0.308922 #############
    4.  "f_4"  0.271175 #########
    5. "f_19"  0.221570 ###
    6. "f_10"  0.215666 ##
    7. "f_11"  0.206509 #
    8. "f_22"  0.204742 #
    9. "f_25"  0.204497 #
   10. "f_23"  0.203238 
   11. "f_21"  0.200830 
   12. "f_24"  0.200445 
   13. "f_12"  0.198840 
   14. "f_18"  0.197676 
   15. "f_20"  0.196634 
   16.  "f_6"  0.196085 
   17. "f_16"  0.196061 
   18.  "f_2"  0.195683 
   19.  "f_5"  0.195683 
   20. "f_13"  0.195559 
   21. "f_17"  0.195559 

Variable Importance: NUM_AS_ROOT:

    1. "f_3"  4.000000 ################
    2. "f_4"  4.000000 ################
    3. "f_8"  3.000000 ##########
    4. "f_9"  1.000000 

Variable Importance: NUM_NODES:

    1.  "f_8" 25.000000 ################
    2. "f_19" 18.000000 ###########
    3. "f_10" 15.000000 #########
    4.  "f_9" 14.000000 ########
    5.  "f_3" 13.000000 ########
    6. "f_23"  7.000000 ####
    7. "f_24"  6.000000 ###
    8. "f_11"  5.000000 ##
    9. "f_21"  5.000000 ##
   10. "f_25"  5.000000 ##
   11.  "f_4"  5.000000 ##
   12. "f_22"  4.000000 ##
   13. "f_12"  3.000000 #
   14. "f_20"  3.000000 #
   15. "f_16"  2.000000 
   16.  "f_6"  2.000000 
   17. "f_13"  1.000000 
   18. "f_17"  1.000000 
   19. "f_18"  1.000000 
   20.  "f_2"  1.000000 
   21.  "f_5"  1.000000 

Variable Importance: SUM_SCORE:

    1.  "f_8" 10779.340861 ################
    2.  "f_9" 8831.772410 #############
    3.  "f_3" 4526.101184 ######
    4.  "f_4" 4360.245403 ######
    5. "f_19" 2325.288894 ###
    6. "f_10" 1881.848369 ##
    7. "f_21" 1674.980191 ##
    8. "f_11" 1127.632256 #
    9. "f_23" 1021.834252 #
   10. "f_24" 914.851512 #
   11. "f_22" 885.619576 #
   12. "f_25" 748.665007 #
   13. "f_20" 310.610858 
   14. "f_16" 298.972842 
   15.  "f_6" 212.376573 
   16. "f_12" 130.725240 
   17.  "f_2" 112.124991 
   18. "f_18" 86.341193 
   19.  "f_5" 65.103908 
   20. "f_13" 57.966947 
   21. "f_17" 21.930388 



Loss: LAMBDA_MART_NDCG5
Validation loss value: -0.438692
Number of trees per iteration: 1
Node format: NOT_SET
Number of trees: 12
Total number of nodes: 286

Number of nodes by tree:
Count: 12 Average: 23.8333 StdDev: 3.50793
Min: 17 Max: 29 Ignored: 0
----------------------------------------------
[ 17, 18) 1   8.33%   8.33% ###
[ 18, 19) 0   0.00%   8.33%
[ 19, 20) 1   8.33%  16.67% ###
[ 20, 21) 0   0.00%  16.67%
[ 21, 22) 2  16.67%  33.33% #######
[ 22, 23) 0   0.00%  33.33%
[ 23, 24) 1   8.33%  41.67% ###
[ 24, 25) 0   0.00%  41.67%
[ 25, 26) 3  25.00%  66.67% ##########
[ 26, 27) 0   0.00%  66.67%
[ 27, 28) 3  25.00%  91.67% ##########
[ 28, 29) 0   0.00%  91.67%
[ 29, 29] 1   8.33% 100.00% ###

Depth by leafs:
Count: 149 Average: 4.14094 StdDev: 1.08696
Min: 1 Max: 5 Ignored: 0
----------------------------------------------
[ 1, 2)  2   1.34%   1.34%
[ 2, 3) 18  12.08%  13.42% ##
[ 3, 4) 13   8.72%  22.15% ##
[ 4, 5) 40  26.85%  48.99% #####
[ 5, 5] 76  51.01% 100.00% ##########

Number of training obs by leaf:
Count: 149 Average: 673.691 StdDev: 2015.44
Min: 5 Max: 8211 Ignored: 0
----------------------------------------------
[    5,  415) 127  85.23%  85.23% ##########
[  415,  825)   6   4.03%  89.26%
[  825, 1236)   2   1.34%  90.60%
[ 1236, 1646)   0   0.00%  90.60%
[ 1646, 2056)   0   0.00%  90.60%
[ 2056, 2467)   1   0.67%  91.28%
[ 2467, 2877)   0   0.00%  91.28%
[ 2877, 3287)   0   0.00%  91.28%
[ 3287, 3698)   1   0.67%  91.95%
[ 3698, 4108)   0   0.00%  91.95%
[ 4108, 4518)   0   0.00%  91.95%
[ 4518, 4929)   1   0.67%  92.62%
[ 4929, 5339)   0   0.00%  92.62%
[ 5339, 5749)   0   0.00%  92.62%
[ 5749, 6160)   1   0.67%  93.29%
[ 6160, 6570)   0   0.00%  93.29%
[ 6570, 6980)   0   0.00%  93.29%
[ 6980, 7391)   0   0.00%  93.29%
[ 7391, 7801)   8   5.37%  98.66% #
[ 7801, 8211]   2   1.34% 100.00%

Attribute in nodes:
    25 : f_8 [NUMERICAL]
    18 : f_19 [NUMERICAL]
    15 : f_10 [NUMERICAL]
    14 : f_9 [NUMERICAL]
    13 : f_3 [NUMERICAL]
    7 : f_23 [NUMERICAL]
    6 : f_24 [NUMERICAL]
    5 : f_4 [NUMERICAL]
    5 : f_25 [NUMERICAL]
    5 : f_21 [NUMERICAL]
    5 : f_11 [NUMERICAL]
    4 : f_22 [NUMERICAL]
    3 : f_20 [NUMERICAL]
    3 : f_12 [NUMERICAL]
    2 : f_6 [NUMERICAL]
    2 : f_16 [NUMERICAL]
    1 : f_5 [NUMERICAL]
    1 : f_2 [NUMERICAL]
    1 : f_18 [NUMERICAL]
    1 : f_17 [NUMERICAL]
    1 : f_13 [NUMERICAL]

Attribute in nodes with depth <= 0:
    4 : f_4 [NUMERICAL]
    4 : f_3 [NUMERICAL]
    3 : f_8 [NUMERICAL]
    1 : f_9 [NUMERICAL]

Attribute in nodes with depth <= 1:
    11 : f_9 [NUMERICAL]
    9 : f_8 [NUMERICAL]
    4 : f_4 [NUMERICAL]
    4 : f_3 [NUMERICAL]
    1 : f_25 [NUMERICAL]
    1 : f_24 [NUMERICAL]
    1 : f_23 [NUMERICAL]
    1 : f_22 [NUMERICAL]
    1 : f_19 [NUMERICAL]
    1 : f_11 [NUMERICAL]

Attribute in nodes with depth <= 2:
    15 : f_8 [NUMERICAL]
    12 : f_9 [NUMERICAL]
    11 : f_3 [NUMERICAL]
    6 : f_19 [NUMERICAL]
    5 : f_4 [NUMERICAL]
    2 : f_25 [NUMERICAL]
    2 : f_11 [NUMERICAL]
    2 : f_10 [NUMERICAL]
    1 : f_24 [NUMERICAL]
    1 : f_23 [NUMERICAL]
    1 : f_22 [NUMERICAL]
    1 : f_18 [NUMERICAL]
    1 : f_17 [NUMERICAL]

Attribute in nodes with depth <= 3:
    22 : f_8 [NUMERICAL]
    13 : f_9 [NUMERICAL]
    11 : f_3 [NUMERICAL]
    10 : f_19 [NUMERICAL]
    9 : f_10 [NUMERICAL]
    5 : f_4 [NUMERICAL]
    5 : f_23 [NUMERICAL]
    5 : f_11 [NUMERICAL]
    4 : f_25 [NUMERICAL]
    4 : f_22 [NUMERICAL]
    4 : f_21 [NUMERICAL]
    3 : f_24 [NUMERICAL]
    2 : f_12 [NUMERICAL]
    1 : f_18 [NUMERICAL]
    1 : f_17 [NUMERICAL]

Attribute in nodes with depth <= 5:
    25 : f_8 [NUMERICAL]
    18 : f_19 [NUMERICAL]
    15 : f_10 [NUMERICAL]
    14 : f_9 [NUMERICAL]
    13 : f_3 [NUMERICAL]
    7 : f_23 [NUMERICAL]
    6 : f_24 [NUMERICAL]
    5 : f_4 [NUMERICAL]
    5 : f_25 [NUMERICAL]
    5 : f_21 [NUMERICAL]
    5 : f_11 [NUMERICAL]
    4 : f_22 [NUMERICAL]
    3 : f_20 [NUMERICAL]
    3 : f_12 [NUMERICAL]
    2 : f_6 [NUMERICAL]
    2 : f_16 [NUMERICAL]
    1 : f_5 [NUMERICAL]
    1 : f_2 [NUMERICAL]
    1 : f_18 [NUMERICAL]
    1 : f_17 [NUMERICAL]
    1 : f_13 [NUMERICAL]

Condition type in nodes:
    137 : HigherCondition
Condition type in nodes with depth <= 0:
    12 : HigherCondition
Condition type in nodes with depth <= 1:
    34 : HigherCondition
Condition type in nodes with depth <= 2:
    60 : HigherCondition
Condition type in nodes with depth <= 3:
    99 : HigherCondition
Condition type in nodes with depth <= 5:
    137 : HigherCondition

Training logs:
Number of iteration to final model: 12
    Iter:1 train-loss:-0.346669 valid-loss:-0.262935  train-NDCG@5:0.346669 valid-NDCG@5:0.262935
    Iter:2 train-loss:-0.412635 valid-loss:-0.335301  train-NDCG@5:0.412635 valid-NDCG@5:0.335301
    Iter:3 train-loss:-0.468270 valid-loss:-0.341295  train-NDCG@5:0.468270 valid-NDCG@5:0.341295
    Iter:4 train-loss:-0.481511 valid-loss:-0.301897  train-NDCG@5:0.481511 valid-NDCG@5:0.301897
    Iter:5 train-loss:-0.473165 valid-loss:-0.394670  train-NDCG@5:0.473165 valid-NDCG@5:0.394670
    Iter:6 train-loss:-0.496260 valid-loss:-0.415201  train-NDCG@5:0.496260 valid-NDCG@5:0.415201
    Iter:16 train-loss:-0.526791 valid-loss:-0.380900  train-NDCG@5:0.526791 valid-NDCG@5:0.380900
    Iter:26 train-loss:-0.560398 valid-loss:-0.367496  train-NDCG@5:0.560398 valid-NDCG@5:0.367496
    Iter:36 train-loss:-0.584252 valid-loss:-0.341845  train-NDCG@5:0.584252 valid-NDCG@5:0.341845

If you are curious, you can also plot the model:

tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)

Predicting with a ranking model

For an incoming query, we can use our ranking model to predict the relevance of a stack of documents. In practice, this means that for each query we must come up with a set of documents that may or may not be relevant to it. We call these documents the **candidate documents**. For each query/candidate-document pair, we can compute the same features used during training. This is our **serving dataset**.

Going back to the example from the beginning of this tutorial, the serving dataset might look like this:

| query | document_id | feature_1 | feature_2 |
|-------|-------------|-----------|-----------|
| fish  | 32          | 0.3       | blue      |
| fish  | 33          | 1.0       | green     |
| fish  | 34          | 0.4       | blue      |
| fish  | 35          | NA        | brown     |

Note that the relevance is not part of the serving dataset, since this is what the model is trying to predict.

The serving dataset is fed to the TF-DF model, which assigns a relevance score to each document:

| query | document_id | feature_1 | feature_2 | relevance |
|-------|-------------|-----------|-----------|-----------|
| fish  | 32          | 0.3       | blue      | 0.325     |
| fish  | 33          | 1.0       | green     | 0.125     |
| fish  | 34          | 0.4       | blue      | 0.155     |
| fish  | 35          | NA        | brown     | 0.593     |

This means that the document with document_id 35 is predicted to be the most relevant for the query "fish".

Let's try to do this with our real model.

# Path to a test dataset using libsvm format.
test_dataset_path = os.path.join(os.path.dirname(archive_path),"OHSUMED/Data/Fold1/testset.txt")
# Convert the dataset.
csv_test_dataset_path="/tmp/ohsumed_test.csv"
convert_libsvm_to_csv(test_dataset_path, csv_test_dataset_path)

# Load a dataset into a Pandas Dataframe.
test_dataset_df = pd.read_csv(csv_test_dataset_path)

# Display the first 3 examples.
test_dataset_df.head(3)

Suppose our query is "g_5" and the test dataset already contains the candidate documents for this query.

# Filter by "g_5"
serving_dataset_df = test_dataset_df[test_dataset_df['group'] == 'g_5']
# Remove the columns for group and relevance, not needed for predictions.
serving_dataset_df = serving_dataset_df.drop(['relevance', 'group'], axis=1)
# Convert to a Tensorflow dataset
serving_dataset_ds = tfdf.keras.pd_dataframe_to_tf_dataset(serving_dataset_df, task=tfdf.keras.Task.RANKING)
# Run predictions on all candidate documents
predictions = model.predict(serving_dataset_ds)
1/1 [==============================] - 0s 176ms/step

We can add the predictions to the dataframe and use them to find the documents with the highest scores.

serving_dataset_df['prediction_score'] = predictions
serving_dataset_df.sort_values(by=['prediction_score'], ascending=False).head()
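If you only need the positions of the top candidates (rather than the full sorted dataframe), the prediction array returned by model.predict can be used directly (a small follow-up sketch):

# model.predict returns one score per candidate; argsort gives a descending ranking.
top_k = 3
order = predictions.squeeze().argsort()[::-1]
print(serving_dataset_df.iloc[order[:top_k]])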