Passage Ranking using TFR-BERT


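TensorFlow Ranking can handle heterogeneous dense and sparse features, and scales up to millions of data points. However, building and deploying a learning-to-rank model that operates at scale poses challenges beyond designing the model itself. The Ranking library provides workflow utility classes for building distributed training for large-scale ranking applications. For more information about these features, see the TensorFlow Ranking overview.

This tutorial shows how to build a ranking model that uses BERT as its scoring function. BERT is a highly effective pretrained module that encodes text features into contextualized word embeddings. We initialize the ranking model with BERT and fine-tune it with a ranking loss.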

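ANTIQUE dataset

In this tutorial, you will use BERT as the scoring function to build a ranking model for ANTIQUE, a question-answering dataset. Bidirectional Encoder Representations from Transformers (BERT) is a Transformer-based machine learning technique that has proven effective on many natural language processing (NLP) tasks. Recent work on TFR-BERT has shown that BERT is an effective scoring function for learning-to-rank tasks.

Given a query and a list of answers, the ranking model orders the answers to maximize a ranking metric such as NDCG. For more details on ranking metrics, see Evaluation measures: offline metrics.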
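The sketch below computes NDCG@k for a single query in plain Python, to show roughly what the metric measures. It assumes the common exponential gain 2^label - 1 and the logarithmic discount 1/log2(rank + 1). This is our reading of TF-Ranking's default NDCG definition, not code taken from the library, and the labels are hypothetical.

```python
import math


def dcg_at_k(labels_in_ranked_order, k):
    """DCG@k with gain 2^label - 1 and discount log2(rank + 1), ranks starting at 1."""
    dcg = 0.0
    for rank, label in enumerate(labels_in_ranked_order[:k], start=1):
        dcg += (2 ** label - 1) / math.log2(rank + 1)
    return dcg


def ndcg_at_k(labels_in_ranked_order, k):
    """NDCG@k: DCG of the model's ranking divided by DCG of the ideal ranking."""
    ideal_dcg = dcg_at_k(sorted(labels_in_ranked_order, reverse=True), k)
    if ideal_dcg == 0.0:
        return 0.0  # no relevant answers at all: define NDCG as 0
    return dcg_at_k(labels_in_ranked_order, k) / ideal_dcg


# Hypothetical graded labels (0-4) of three answers, in the order the model ranked them.
ranked_labels = [3, 4, 0]
print(round(ndcg_at_k(ranked_labels, 3), 4))
```

For the hypothetical ranking [3, 4, 0], NDCG@3 is about 0.848. Swapping the first two answers gives the ideal order and an NDCG of exactly 1.0.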

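ANTIQUE is a publicly available dataset for open-domain, non-factoid question answering, collected from Yahoo! Answers. Each question has a list of answers whose relevance is graded on a scale of 0 to 4, where 0 means irrelevant and 4 means fully relevant. Because the list size varies from query to query, we use a fixed "list size" of 50: longer lists are truncated and shorter lists are padded with default values. The dataset is split into 2206 queries for training and 200 queries for testing. For more details, read the technical paper on arXiv.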
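The sketch below pads or truncates a list of relevance labels in plain Python to show how a fixed list size works. It assumes padded entries receive the label -1, which matches the default_value of the relevance label spec defined later in this tutorial. The function name pad_or_truncate is hypothetical, and the real input pipeline may choose which answers to keep differently.

```python
PAD_LABEL = -1  # assumed to match default_value=-1 in this tutorial's relevance label spec


def pad_or_truncate(labels, list_size, pad_value=PAD_LABEL):
    """Hypothetical helper: return exactly list_size labels."""
    if len(labels) >= list_size:
        return labels[:list_size]  # too long: keep the first list_size entries
    return labels + [pad_value] * (list_size - len(labels))  # too short: pad


padded = pad_or_truncate([4, 2, 0], 5)
print(padded)                               # [4, 2, 0, -1, -1]
print(pad_or_truncate([4, 2, 0, 1, 3], 2))  # [4, 2]
# Only non-negative labels correspond to real answers.
valid = [label >= 0 for label in padded]
print(valid)                                # [True, True, True, False, False]
```

Negative labels mark padding. TF-Ranking treats entries with labels below 0 as invalid and leaves them out of losses and metrics.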

Setup

Download and install the TensorFlow Ranking and TensorFlow Model Garden packages.

pip install -q tensorflow-ranking tf-models-official

Import TensorFlow Ranking and other useful libraries into the notebook.

import os
import tensorflow as tf
import tensorflow_ranking as tfr
from official.nlp.configs import encoders
from tensorflow_ranking.extension.premade import tfrbert_task

Data preparation

Download the training and test data.

wget -O "/tmp/train.tfrecords" "https://ciir.cs.umass.edu/downloads/Antique/tf-ranking/antique_train_seq_64_elwc.tfrecords"
wget -O "/tmp/test.tfrecords" "https://ciir.cs.umass.edu/downloads/Antique/tf-ranking/antique_test_seq_64_elwc.tfrecords"
--2022-12-14 12:15:40--  https://ciir.cs.umass.edu/downloads/Antique/tf-ranking/antique_train_seq_64_elwc.tfrecords
Resolving ciir.cs.umass.edu (ciir.cs.umass.edu)... 128.119.246.154
Connecting to ciir.cs.umass.edu (ciir.cs.umass.edu)|128.119.246.154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 8743528 (8.3M)
Saving to: ‘/tmp/train.tfrecords’

/tmp/train.tfrecord 100%[===================>]   8.34M  12.0MB/s    in 0.7s    

2022-12-14 12:15:41 (12.0 MB/s) - ‘/tmp/train.tfrecords’ saved [8743528/8743528]

--2022-12-14 12:15:41--  https://ciir.cs.umass.edu/downloads/Antique/tf-ranking/antique_test_seq_64_elwc.tfrecords
Resolving ciir.cs.umass.edu (ciir.cs.umass.edu)... 128.119.246.154
Connecting to ciir.cs.umass.edu (ciir.cs.umass.edu)|128.119.246.154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 692072 (676K)
Saving to: ‘/tmp/test.tfrecords’

/tmp/test.tfrecords 100%[===================>] 675.85K  3.93MB/s    in 0.2s    

2022-12-14 12:15:41 (3.93 MB/s) - ‘/tmp/test.tfrecords’ saved [692072/692072]

Download and extract the pretrained BERT checkpoint.

mkdir -p /tmp/tfrbert
wget "https://storage.googleapis.com/cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12.tar.gz" -P "/tmp/tfrbert"
mkdir -p /tmp/tfrbert/uncased_L-12_H-768_A-12
tar -xvf /tmp/tfrbert/uncased_L-12_H-768_A-12.tar.gz --strip-components 3 -C "/tmp/tfrbert/uncased_L-12_H-768_A-12/"
--2022-12-14 12:15:41--  https://storage.googleapis.com/cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12.tar.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.196.128, 142.251.107.128, 142.250.97.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.196.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 405351189 (387M) [application/octet-stream]
Saving to: ‘/tmp/tfrbert/uncased_L-12_H-768_A-12.tar.gz’

uncased_L-12_H-768_ 100%[===================>] 386.57M   143MB/s    in 2.7s    

2022-12-14 12:15:44 (143 MB/s) - ‘/tmp/tfrbert/uncased_L-12_H-768_A-12.tar.gz’ saved [405351189/405351189]

tmp/temp_dir/raw/vocab.txt
tmp/temp_dir/raw/bert_model.ckpt.index
tmp/temp_dir/raw/bert_model.ckpt.data-00000-of-00001
tmp/temp_dir/raw/bert_config.json
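The checkpoint directory name follows the naming convention of Google's BERT releases: L is the number of Transformer layers, H the hidden size, and A the number of attention heads. The hypothetical helper below decodes that name. The num_layers=12 passed to BertEncoderConfig later in this tutorial should agree with L so that the pretrained weights load correctly.

```python
import re


def parse_bert_checkpoint_name(name):
    """Hypothetical helper: extract (layers, hidden_size, heads) from e.g. 'uncased_L-12_H-768_A-12'."""
    match = re.search(r"L-(\d+)_H-(\d+)_A-(\d+)", name)
    if match is None:
        raise ValueError(f"unrecognized checkpoint name: {name}")
    layers, hidden, heads = (int(group) for group in match.groups())
    return layers, hidden, heads


layers, hidden, heads = parse_bert_checkpoint_name("uncased_L-12_H-768_A-12")
print(layers, hidden, heads)  # 12 768 12
print(hidden // heads)        # size of each attention head: 64
```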

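Overview of TFR-BERT in Orbit

BERT-based ranking models (TFR-BERT) have been shown to be effective for learning-to-rank tasks when they use raw text features for queries and passages in the MSMARCO passage ranking dataset.

Orbit is a flexible, lightweight library designed to make it easy to write custom training loops in TensorFlow. TensorFlow Ranking supports implementing ranking models with Orbit, in particular BERT-based ranking models.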

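Create a ranking task for TFR-BERT

We create a ranking task for the TFR-BERT model that can be trained with Orbit. The steps to build this task are:

  1. Define the feature specs
  2. Define the datasets
  3. Set up the data and task configs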

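Specify features

Feature specifications are TensorFlow abstractions that capture information about each feature, such as its shape, dtype, and default value. They help developers and model researchers understand and use the model.

Create feature specs for the context features, example features, and labels that are consistent with the ranking input format, such as the ELWC (ExampleListWithContext) format.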

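SEQ_LENGTH = 64  # Number of BERT tokens per (query, answer) pair.
# No context features: the query is already packed into each example's BERT input.
context_feature_spec = {}
# Per-answer BERT inputs: token ids, attention mask, and segment (type) ids.
example_feature_spec = {
    'input_word_ids': tf.io.FixedLenFeature(
        shape=(SEQ_LENGTH,), dtype=tf.int64,
        default_value=[0] * SEQ_LENGTH),
    'input_mask': tf.io.FixedLenFeature(
        shape=(SEQ_LENGTH,), dtype=tf.int64,
        default_value=[0] * SEQ_LENGTH),
    'input_type_ids': tf.io.FixedLenFeature(
        shape=(SEQ_LENGTH,), dtype=tf.int64,
        default_value=[0] * SEQ_LENGTH)}
# Graded relevance label. Padded list entries get -1 so they are ignored.
label_spec = (
    "relevance",
    tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64, default_value=-1)
)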
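The sketch below packs one (query, answer) pair into BERT inputs of length SEQ_LENGTH in plain Python, to make the three example features concrete. It rests on several assumptions:

* the standard BERT sentence-pair layout [CLS] query [SEP] answer [SEP];
* the [CLS], [SEP], and [PAD] ids 101, 102, and 0 from the uncased BERT vocab.txt;
* a simple truncation policy.

The helper name and the token ids in the example are hypothetical. The preprocessing that produced the ANTIQUE TFRecords may differ in its details.

```python
SEQ_LENGTH = 64
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # assumed ids from the uncased BERT vocab.txt


def pack_query_answer(query_ids, answer_ids, seq_length=SEQ_LENGTH):
    """Hypothetical helper: build the three BERT features for one (query, answer) pair."""
    budget = seq_length - 3  # room left after [CLS] and two [SEP] tokens
    query_ids = query_ids[:budget]                     # assumed policy: keep the query first
    answer_ids = answer_ids[:budget - len(query_ids)]  # then fill the rest with the answer
    word_ids = [CLS_ID] + query_ids + [SEP_ID] + answer_ids + [SEP_ID]
    # Segment 0 covers [CLS] query [SEP]; segment 1 covers answer [SEP].
    type_ids = [0] * (len(query_ids) + 2) + [1] * (len(answer_ids) + 1)
    mask = [1] * len(word_ids)  # 1 marks a real token
    pad = seq_length - len(word_ids)
    return {
        "input_word_ids": word_ids + [PAD_ID] * pad,
        "input_mask": mask + [0] * pad,
        "input_type_ids": type_ids + [0] * pad,
    }


features = pack_query_answer([2054, 2003], [1037, 3231])  # made-up wordpiece ids
print(features["input_word_ids"][:8])  # [101, 2054, 2003, 102, 1037, 3231, 102, 0]
print(features["input_type_ids"][:8])  # [0, 0, 0, 0, 1, 1, 1, 0]
print(sum(features["input_mask"]))     # 7
```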

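Define datasets

We define data configs for the training and validation data. They specify parameters such as the input path, batch size, list size, and dataset format, and are used to create the training and validation datasets.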

# Set up data config
# We use a small list size here for demo purposes only. Users can use a larger
# list size on a machine with more memory to train TFR-BERT.
train_data_config = tfrbert_task.TFRBertDataConfig(
    input_path="/tmp/train.tfrecords",
    is_training=True,
    global_batch_size=8,
    list_size=2,
    dataset_fn='tfrecord',
    seq_length=64)

validation_data_config = tfrbert_task.TFRBertDataConfig(
    input_path="/tmp/test.tfrecords",
    is_training=False,
    global_batch_size=8,
    list_size=2,
    dataset_fn='tfrecord',
    seq_length=64)
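With these settings, each batch presumably holds global_batch_size queries with list_size answers each. Every per-example feature such as input_word_ids would then have shape [global_batch_size, list_size, seq_length], and BERT scores one sequence per (query, answer) pair. The rough calculation below (plain Python, hypothetical helper names) shows why the demo uses list_size=2 instead of the full list size of 50.

```python
def feature_shape(global_batch_size, list_size, seq_length):
    """Assumed shape [batch, list_size, seq_length] of a feature like input_word_ids."""
    return (global_batch_size, list_size, seq_length)


def bert_sequences_per_step(global_batch_size, list_size):
    """Number of (query, answer) pairs the encoder scores in one training step."""
    return global_batch_size * list_size


demo = bert_sequences_per_step(8, 2)   # this tutorial's train_data_config
full = bert_sequences_per_step(8, 50)  # same batch size, full ANTIQUE list size
print(feature_shape(8, 2, 64))         # (8, 2, 64)
print(demo, full, full // demo)        # 16 400 25
```

At the same batch size, going from 2 to 50 answers per query multiplies the number of BERT forward passes per step by 25. That is what makes the full setting memory-hungry.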

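Define the task

We define a task config that specifies the training and validation datasets, the model, and the pretrained BERT checkpoint used for initialization. This config is then used to create a TFRBertTask object, which can be trained with Orbit.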

# Set up task config
task_config = tfrbert_task.TFRBertConfig(
    init_checkpoint='/tmp/tfrbert/uncased_L-12_H-768_A-12/bert_model.ckpt',
    train_data=train_data_config,
    validation_data=validation_data_config,
    model=tfrbert_task.TFRBertModelConfig(
        encoder=encoders.EncoderConfig(
            bert=encoders.BertEncoderConfig(num_layers=12))))

# Set up TFRBertTask
task = tfrbert_task.TFRBertTask(
    task_config,
    label_spec=label_spec,
    dataset_fn=tf.data.TFRecordDataset,
    logging_dir='/tmp/model_dir')

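Train and evaluate the model

Here we define a training loop to train and evaluate the model. We build the metrics, create the training and evaluation datasets, and train the model for a fixed number of steps, evaluating on one validation batch every EVAL_STEPS steps.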

metrics = task.build_metrics()
model = task.build_model()
task.initialize(model)  # Loads the pretrained BERT weights from init_checkpoint.
train_dataset = task.build_inputs(task_config.train_data)
vali_dataset = task.build_inputs(task_config.validation_data)
train_iterator = iter(train_dataset)
vali_iterator = iter(vali_dataset)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-6)

NUM_TRAIN_STEPS = 100
EVAL_STEPS = 10
# Note: the "epoch" number printed below is the training step index.
for train_step in range(NUM_TRAIN_STEPS):
  task.train_step(next(train_iterator), model, optimizer, metrics=metrics)
  # The metric objects are never reset, so result() is a running average
  # over every batch seen so far.
  train_metrics = {m.name: m.result().numpy() for m in metrics}
  print("Training metrics for epoch: " + str(train_step) + " ", train_metrics)

  # Every EVAL_STEPS steps, evaluate on one batch from the validation set.
  if train_step % EVAL_STEPS == 0:
    task.validation_step(next(vali_iterator), model, metrics=metrics)
    vali_metrics = {m.name: m.result().numpy() for m in metrics}
    print("Validation metrics for epoch: " + str(train_step) + " ",
          vali_metrics)
Training metrics for epoch: 0  {'MAP': 0.9375, 'NDCG@1': 0.73214287, 'NDCG@5': 0.912364, 'NDCG@10': 0.912364, 'MRR@1': 0.875, 'MRR@5': 0.9375, 'MRR@10': 0.9375}
Validation metrics for epoch: 0  {'MAP': 0.96875, 'NDCG@1': 0.66369045, 'NDCG@5': 0.89421266, 'NDCG@10': 0.89421266, 'MRR@1': 0.9375, 'MRR@5': 0.96875, 'MRR@10': 0.96875}
Training metrics for epoch: 1  {'MAP': 0.9583333, 'NDCG@1': 0.6507936, 'NDCG@5': 0.88817185, 'NDCG@10': 0.88817185, 'MRR@1': 0.9166667, 'MRR@5': 0.9583333, 'MRR@10': 0.9583333}
Training metrics for epoch: 2  {'MAP': 0.96875, 'NDCG@1': 0.6577381, 'NDCG@5': 0.89149714, 'NDCG@10': 0.89149714, 'MRR@1': 0.9375, 'MRR@5': 0.96875, 'MRR@10': 0.96875}
Training metrics for epoch: 3  {'MAP': 0.975, 'NDCG@1': 0.68095237, 'NDCG@5': 0.89981496, 'NDCG@10': 0.89981496, 'MRR@1': 0.95, 'MRR@5': 0.975, 'MRR@10': 0.975}
Training metrics for epoch: 4  {'MAP': 0.9791667, 'NDCG@1': 0.71031743, 'NDCG@5': 0.9095955, 'NDCG@10': 0.9095955, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 5  {'MAP': 0.98214287, 'NDCG@1': 0.7091837, 'NDCG@5': 0.9085163, 'NDCG@10': 0.9085163, 'MRR@1': 0.96428573, 'MRR@5': 0.98214287, 'MRR@10': 0.98214287}
Training metrics for epoch: 6  {'MAP': 0.9765625, 'NDCG@1': 0.68526787, 'NDCG@5': 0.8999288, 'NDCG@10': 0.8999288, 'MRR@1': 0.953125, 'MRR@5': 0.9765625, 'MRR@10': 0.9765625}
Training metrics for epoch: 7  {'MAP': 0.9791667, 'NDCG@1': 0.7030423, 'NDCG@5': 0.90591866, 'NDCG@10': 0.90591866, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 8  {'MAP': 0.98125, 'NDCG@1': 0.7255953, 'NDCG@5': 0.9132517, 'NDCG@10': 0.9132517, 'MRR@1': 0.9625, 'MRR@5': 0.98125, 'MRR@10': 0.98125}
Training metrics for epoch: 9  {'MAP': 0.97727275, 'NDCG@1': 0.7229437, 'NDCG@5': 0.9117598, 'NDCG@10': 0.9117598, 'MRR@1': 0.95454544, 'MRR@5': 0.97727275, 'MRR@10': 0.97727275}
Training metrics for epoch: 10  {'MAP': 0.9791667, 'NDCG@1': 0.7311508, 'NDCG@5': 0.91436106, 'NDCG@10': 0.91436106, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Validation metrics for epoch: 10  {'MAP': 0.97596157, 'NDCG@1': 0.7339744, 'NDCG@5': 0.9146096, 'NDCG@10': 0.9146096, 'MRR@1': 0.9519231, 'MRR@5': 0.97596157, 'MRR@10': 0.97596157}
Training metrics for epoch: 11  {'MAP': 0.9776786, 'NDCG@1': 0.73511904, 'NDCG@5': 0.9151535, 'NDCG@10': 0.9151535, 'MRR@1': 0.95535713, 'MRR@5': 0.9776786, 'MRR@10': 0.9776786}
Training metrics for epoch: 12  {'MAP': 0.975, 'NDCG@1': 0.7253969, 'NDCG@5': 0.91220075, 'NDCG@10': 0.91220075, 'MRR@1': 0.95, 'MRR@5': 0.975, 'MRR@10': 0.975}
Training metrics for epoch: 13  {'MAP': 0.9765625, 'NDCG@1': 0.73363096, 'NDCG@5': 0.9150943, 'NDCG@10': 0.9150943, 'MRR@1': 0.953125, 'MRR@5': 0.9765625, 'MRR@10': 0.9765625}
Training metrics for epoch: 14  {'MAP': 0.97794116, 'NDCG@1': 0.7366947, 'NDCG@5': 0.91642684, 'NDCG@10': 0.91642684, 'MRR@1': 0.9558824, 'MRR@5': 0.97794116, 'MRR@10': 0.97794116}
Training metrics for epoch: 15  {'MAP': 0.9791667, 'NDCG@1': 0.7433862, 'NDCG@5': 0.9187641, 'NDCG@10': 0.9187641, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 16  {'MAP': 0.9769737, 'NDCG@1': 0.73903507, 'NDCG@5': 0.91679335, 'NDCG@10': 0.91679335, 'MRR@1': 0.95394737, 'MRR@5': 0.9769737, 'MRR@10': 0.9769737}
Training metrics for epoch: 17  {'MAP': 0.978125, 'NDCG@1': 0.7407738, 'NDCG@5': 0.91760796, 'NDCG@10': 0.91760796, 'MRR@1': 0.95625, 'MRR@5': 0.978125, 'MRR@10': 0.978125}
Training metrics for epoch: 18  {'MAP': 0.9791667, 'NDCG@1': 0.73781174, 'NDCG@5': 0.9168396, 'NDCG@10': 0.9168396, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 19  {'MAP': 0.9801136, 'NDCG@1': 0.7464827, 'NDCG@5': 0.91967636, 'NDCG@10': 0.91967636, 'MRR@1': 0.96022725, 'MRR@5': 0.9801136, 'MRR@10': 0.9801136}
Training metrics for epoch: 20  {'MAP': 0.9782609, 'NDCG@1': 0.7380952, 'NDCG@5': 0.91687906, 'NDCG@10': 0.91687906, 'MRR@1': 0.95652175, 'MRR@5': 0.9782609, 'MRR@10': 0.9782609}
Validation metrics for epoch: 20  {'MAP': 0.9791667, 'NDCG@1': 0.74900794, 'NDCG@5': 0.92034245, 'NDCG@10': 0.92034245, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 21  {'MAP': 0.98, 'NDCG@1': 0.7561905, 'NDCG@5': 0.9226987, 'NDCG@10': 0.9226987, 'MRR@1': 0.96, 'MRR@5': 0.98, 'MRR@10': 0.98}
Training metrics for epoch: 22  {'MAP': 0.9807692, 'NDCG@1': 0.7545787, 'NDCG@5': 0.92208344, 'NDCG@10': 0.92208344, 'MRR@1': 0.96153843, 'MRR@5': 0.9807692, 'MRR@10': 0.9807692}
Training metrics for epoch: 23  {'MAP': 0.9814815, 'NDCG@1': 0.75837743, 'NDCG@5': 0.9234321, 'NDCG@10': 0.9234321, 'MRR@1': 0.962963, 'MRR@5': 0.9814815, 'MRR@10': 0.9814815}
Training metrics for epoch: 24  {'MAP': 0.98214287, 'NDCG@1': 0.7568027, 'NDCG@5': 0.9228346, 'NDCG@10': 0.9228346, 'MRR@1': 0.96428573, 'MRR@5': 0.98214287, 'MRR@10': 0.98214287}
Training metrics for epoch: 25  {'MAP': 0.98275864, 'NDCG@1': 0.7627258, 'NDCG@5': 0.9247799, 'NDCG@10': 0.9247799, 'MRR@1': 0.9655172, 'MRR@5': 0.98275864, 'MRR@10': 0.98275864}
Training metrics for epoch: 26  {'MAP': 0.98333335, 'NDCG@1': 0.76468253, 'NDCG@5': 0.9253864, 'NDCG@10': 0.9253864, 'MRR@1': 0.96666664, 'MRR@5': 0.98333335, 'MRR@10': 0.98333335}
Training metrics for epoch: 27  {'MAP': 0.983871, 'NDCG@1': 0.765361, 'NDCG@5': 0.9257851, 'NDCG@10': 0.9257851, 'MRR@1': 0.9677419, 'MRR@5': 0.983871, 'MRR@10': 0.983871}
Training metrics for epoch: 28  {'MAP': 0.984375, 'NDCG@1': 0.7671131, 'NDCG@5': 0.92632234, 'NDCG@10': 0.92632234, 'MRR@1': 0.96875, 'MRR@5': 0.984375, 'MRR@10': 0.984375}
Training metrics for epoch: 29  {'MAP': 0.98295456, 'NDCG@1': 0.7638889, 'NDCG@5': 0.92527056, 'NDCG@10': 0.92527056, 'MRR@1': 0.96590906, 'MRR@5': 0.98295456, 'MRR@10': 0.98295456}
Training metrics for epoch: 30  {'MAP': 0.9834559, 'NDCG@1': 0.7708334, 'NDCG@5': 0.9274685, 'NDCG@10': 0.9274685, 'MRR@1': 0.9669118, 'MRR@5': 0.9834559, 'MRR@10': 0.9834559}
Validation metrics for epoch: 30  {'MAP': 0.98392856, 'NDCG@1': 0.76513606, 'NDCG@5': 0.9256894, 'NDCG@10': 0.9262823, 'MRR@1': 0.9678571, 'MRR@5': 0.98392856, 'MRR@10': 0.98392856}
Training metrics for epoch: 31  {'MAP': 0.984375, 'NDCG@1': 0.765377, 'NDCG@5': 0.92589486, 'NDCG@10': 0.9264713, 'MRR@1': 0.96875, 'MRR@5': 0.984375, 'MRR@10': 0.984375}
Training metrics for epoch: 32  {'MAP': 0.9831081, 'NDCG@1': 0.765444, 'NDCG@5': 0.92567044, 'NDCG@10': 0.92623127, 'MRR@1': 0.9662162, 'MRR@5': 0.9831081, 'MRR@10': 0.9831081}
Training metrics for epoch: 33  {'MAP': 0.98355263, 'NDCG@1': 0.768797, 'NDCG@5': 0.92667186, 'NDCG@10': 0.92721796, 'MRR@1': 0.96710527, 'MRR@5': 0.98355263, 'MRR@10': 0.98355263}
Training metrics for epoch: 34  {'MAP': 0.98397434, 'NDCG@1': 0.7728938, 'NDCG@5': 0.92802, 'NDCG@10': 0.9285521, 'MRR@1': 0.96794873, 'MRR@5': 0.98397434, 'MRR@10': 0.98397434}
Training metrics for epoch: 35  {'MAP': 0.984375, 'NDCG@1': 0.775, 'NDCG@5': 0.92878187, 'NDCG@10': 0.92930067, 'MRR@1': 0.96875, 'MRR@5': 0.984375, 'MRR@10': 0.984375}
Training metrics for epoch: 36  {'MAP': 0.9832317, 'NDCG@1': 0.77482575, 'NDCG@5': 0.92850894, 'NDCG@10': 0.9290151, 'MRR@1': 0.9664634, 'MRR@5': 0.9832317, 'MRR@10': 0.9832317}
Training metrics for epoch: 37  {'MAP': 0.98214287, 'NDCG@1': 0.770975, 'NDCG@5': 0.9271499, 'NDCG@10': 0.927644, 'MRR@1': 0.96428573, 'MRR@5': 0.98214287, 'MRR@10': 0.98214287}
Training metrics for epoch: 38  {'MAP': 0.98255813, 'NDCG@1': 0.7743632, 'NDCG@5': 0.9282532, 'NDCG@10': 0.9287358, 'MRR@1': 0.96511626, 'MRR@5': 0.98255813, 'MRR@10': 0.98255813}
Training metrics for epoch: 39  {'MAP': 0.98295456, 'NDCG@1': 0.7746212, 'NDCG@5': 0.928235, 'NDCG@10': 0.9287066, 'MRR@1': 0.96590906, 'MRR@5': 0.98295456, 'MRR@10': 0.98295456}
Training metrics for epoch: 40  {'MAP': 0.98194444, 'NDCG@1': 0.77685183, 'NDCG@5': 0.9288045, 'NDCG@10': 0.9292657, 'MRR@1': 0.9638889, 'MRR@5': 0.98194444, 'MRR@10': 0.98194444}
Validation metrics for epoch: 40  {'MAP': 0.98097825, 'NDCG@1': 0.7758799, 'NDCG@5': 0.9284471, 'NDCG@10': 0.9288983, 'MRR@1': 0.9619565, 'MRR@5': 0.98233694, 'MRR@10': 0.98233694}
Training metrics for epoch: 41  {'MAP': 0.98138297, 'NDCG@1': 0.7730496, 'NDCG@5': 0.927543, 'NDCG@10': 0.92798454, 'MRR@1': 0.96276593, 'MRR@5': 0.98271275, 'MRR@10': 0.98271275}
Training metrics for epoch: 42  {'MAP': 0.9791667, 'NDCG@1': 0.7703373, 'NDCG@5': 0.9263746, 'NDCG@10': 0.9268069, 'MRR@1': 0.9583333, 'MRR@5': 0.98046875, 'MRR@10': 0.98046875}
Training metrics for epoch: 43  {'MAP': 0.97959185, 'NDCG@1': 0.77502424, 'NDCG@5': 0.9278771, 'NDCG@10': 0.9283007, 'MRR@1': 0.9591837, 'MRR@5': 0.9808673, 'MRR@10': 0.9808673}
Training metrics for epoch: 44  {'MAP': 0.97875, 'NDCG@1': 0.7705952, 'NDCG@5': 0.9264264, 'NDCG@10': 0.92684144, 'MRR@1': 0.9575, 'MRR@5': 0.98, 'MRR@10': 0.98}
Training metrics for epoch: 45  {'MAP': 0.9791667, 'NDCG@1': 0.7708916, 'NDCG@5': 0.9264465, 'NDCG@10': 0.9268534, 'MRR@1': 0.9583333, 'MRR@5': 0.98039216, 'MRR@10': 0.98039216}
Training metrics for epoch: 46  {'MAP': 0.9795673, 'NDCG@1': 0.7732371, 'NDCG@5': 0.9271634, 'NDCG@10': 0.9275625, 'MRR@1': 0.95913464, 'MRR@5': 0.9807692, 'MRR@10': 0.9807692}
Training metrics for epoch: 47  {'MAP': 0.9799528, 'NDCG@1': 0.77279866, 'NDCG@5': 0.9270702, 'NDCG@10': 0.92746174, 'MRR@1': 0.9599057, 'MRR@5': 0.9811321, 'MRR@10': 0.9811321}
Training metrics for epoch: 48  {'MAP': 0.9780093, 'NDCG@1': 0.7718253, 'NDCG@5': 0.92525107, 'NDCG@10': 0.9256354, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 49  {'MAP': 0.9784091, 'NDCG@1': 0.77077913, 'NDCG@5': 0.925101, 'NDCG@10': 0.9254783, 'MRR@1': 0.9590909, 'MRR@5': 0.9795455, 'MRR@10': 0.9795455}
Training metrics for epoch: 50  {'MAP': 0.9776786, 'NDCG@1': 0.7694515, 'NDCG@5': 0.92459637, 'NDCG@10': 0.92496693, 'MRR@1': 0.95758927, 'MRR@5': 0.97879463, 'MRR@10': 0.97879463}
Validation metrics for epoch: 50  {'MAP': 0.9780702, 'NDCG@1': 0.7703634, 'NDCG@5': 0.9249188, 'NDCG@10': 0.92528284, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 51  {'MAP': 0.9773707, 'NDCG@1': 0.7690887, 'NDCG@5': 0.9244347, 'NDCG@10': 0.9247925, 'MRR@1': 0.95689654, 'MRR@5': 0.9784483, 'MRR@10': 0.9784483}
Training metrics for epoch: 52  {'MAP': 0.97775424, 'NDCG@1': 0.76432604, 'NDCG@5': 0.9228256, 'NDCG@10': 0.9231773, 'MRR@1': 0.9576271, 'MRR@5': 0.9788136, 'MRR@10': 0.9788136}
Training metrics for epoch: 53  {'MAP': 0.978125, 'NDCG@1': 0.7664682, 'NDCG@5': 0.9235073, 'NDCG@10': 0.9238531, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 54  {'MAP': 0.9784836, 'NDCG@1': 0.7661983, 'NDCG@5': 0.9234862, 'NDCG@10': 0.9238264, 'MRR@1': 0.9590164, 'MRR@5': 0.9795082, 'MRR@10': 0.9795082}
Training metrics for epoch: 55  {'MAP': 0.97883064, 'NDCG@1': 0.76766515, 'NDCG@5': 0.92405087, 'NDCG@10': 0.92438555, 'MRR@1': 0.9596774, 'MRR@5': 0.9798387, 'MRR@10': 0.9798387}
Training metrics for epoch: 56  {'MAP': 0.9791667, 'NDCG@1': 0.7696523, 'NDCG@5': 0.9246806, 'NDCG@10': 0.92501, 'MRR@1': 0.96031743, 'MRR@5': 0.98015875, 'MRR@10': 0.98015875}
Training metrics for epoch: 57  {'MAP': 0.9785156, 'NDCG@1': 0.7679501, 'NDCG@5': 0.92400306, 'NDCG@10': 0.9243273, 'MRR@1': 0.9589844, 'MRR@5': 0.9794922, 'MRR@10': 0.9794922}
Training metrics for epoch: 58  {'MAP': 0.97884613, 'NDCG@1': 0.76767397, 'NDCG@5': 0.92397565, 'NDCG@10': 0.92429495, 'MRR@1': 0.9596154, 'MRR@5': 0.9798077, 'MRR@10': 0.9798077}
Training metrics for epoch: 59  {'MAP': 0.9782197, 'NDCG@1': 0.7655122, 'NDCG@5': 0.92325014, 'NDCG@10': 0.92356455, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 60  {'MAP': 0.97761196, 'NDCG@1': 0.76341504, 'NDCG@5': 0.92254627, 'NDCG@10': 0.92285603, 'MRR@1': 0.95708954, 'MRR@5': 0.9785448, 'MRR@10': 0.9785448}
Validation metrics for epoch: 60  {'MAP': 0.97794116, 'NDCG@1': 0.76216733, 'NDCG@5': 0.9222364, 'NDCG@10': 0.9227017, 'MRR@1': 0.9577206, 'MRR@5': 0.9788603, 'MRR@10': 0.9788603}
Training metrics for epoch: 61  {'MAP': 0.9782609, 'NDCG@1': 0.7623361, 'NDCG@5': 0.9223936, 'NDCG@10': 0.9228522, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 62  {'MAP': 0.9785714, 'NDCG@1': 0.76164967, 'NDCG@5': 0.92231655, 'NDCG@10': 0.92276853, 'MRR@1': 0.9589286, 'MRR@5': 0.9794643, 'MRR@10': 0.9794643}
Training metrics for epoch: 63  {'MAP': 0.97711265, 'NDCG@1': 0.75796443, 'NDCG@5': 0.9210157, 'NDCG@10': 0.9214613, 'MRR@1': 0.9559859, 'MRR@5': 0.97799295, 'MRR@10': 0.97799295}
Training metrics for epoch: 64  {'MAP': 0.9765625, 'NDCG@1': 0.7581018, 'NDCG@5': 0.9209682, 'NDCG@10': 0.9214076, 'MRR@1': 0.9548611, 'MRR@5': 0.9774306, 'MRR@10': 0.9774306}
Training metrics for epoch: 65  {'MAP': 0.97602737, 'NDCG@1': 0.7577462, 'NDCG@5': 0.9208503, 'NDCG@10': 0.92128366, 'MRR@1': 0.9537671, 'MRR@5': 0.9768836, 'MRR@10': 0.9768836}
Training metrics for epoch: 66  {'MAP': 0.9755068, 'NDCG@1': 0.75546974, 'NDCG@5': 0.9200356, 'NDCG@10': 0.92046314, 'MRR@1': 0.9527027, 'MRR@5': 0.9763514, 'MRR@10': 0.9763514}
Training metrics for epoch: 67  {'MAP': 0.97583336, 'NDCG@1': 0.7563492, 'NDCG@5': 0.9203415, 'NDCG@10': 0.9207634, 'MRR@1': 0.9533333, 'MRR@5': 0.9766667, 'MRR@10': 0.9766667}
Training metrics for epoch: 68  {'MAP': 0.9761513, 'NDCG@1': 0.7581454, 'NDCG@5': 0.9209124, 'NDCG@10': 0.9213287, 'MRR@1': 0.95394737, 'MRR@5': 0.9769737, 'MRR@10': 0.9769737}
Training metrics for epoch: 69  {'MAP': 0.97564936, 'NDCG@1': 0.7578077, 'NDCG@5': 0.92080134, 'NDCG@10': 0.92121226, 'MRR@1': 0.9529221, 'MRR@5': 0.97646105, 'MRR@10': 0.97646105}
Training metrics for epoch: 70  {'MAP': 0.97596157, 'NDCG@1': 0.75862336, 'NDCG@5': 0.92108566, 'NDCG@10': 0.92149127, 'MRR@1': 0.95352566, 'MRR@5': 0.97676283, 'MRR@10': 0.97676283}
Validation metrics for epoch: 70  {'MAP': 0.97626585, 'NDCG@1': 0.75761, 'NDCG@5': 0.9207071, 'NDCG@10': 0.9211076, 'MRR@1': 0.9541139, 'MRR@5': 0.977057, 'MRR@10': 0.977057}
Training metrics for epoch: 71  {'MAP': 0.9765625, 'NDCG@1': 0.75796133, 'NDCG@5': 0.9209201, 'NDCG@10': 0.92131555, 'MRR@1': 0.9546875, 'MRR@5': 0.97734374, 'MRR@10': 0.97734374}
Training metrics for epoch: 72  {'MAP': 0.9768519, 'NDCG@1': 0.75727516, 'NDCG@5': 0.92068696, 'NDCG@10': 0.9210776, 'MRR@1': 0.9552469, 'MRR@5': 0.97762346, 'MRR@10': 0.97762346}
Training metrics for epoch: 73  {'MAP': 0.97713417, 'NDCG@1': 0.7589286, 'NDCG@5': 0.9212119, 'NDCG@10': 0.9215977, 'MRR@1': 0.95579267, 'MRR@5': 0.97789633, 'MRR@10': 0.97789633}
Training metrics for epoch: 74  {'MAP': 0.97740966, 'NDCG@1': 0.7618331, 'NDCG@5': 0.92216116, 'NDCG@10': 0.92254233, 'MRR@1': 0.9563253, 'MRR@5': 0.97816265, 'MRR@10': 0.97816265}
Training metrics for epoch: 75  {'MAP': 0.9776786, 'NDCG@1': 0.761267, 'NDCG@5': 0.9219771, 'NDCG@10': 0.92235374, 'MRR@1': 0.9568452, 'MRR@5': 0.97842264, 'MRR@10': 0.97842264}
Training metrics for epoch: 76  {'MAP': 0.97794116, 'NDCG@1': 0.759874, 'NDCG@5': 0.9215532, 'NDCG@10': 0.9219254, 'MRR@1': 0.95735294, 'MRR@5': 0.9786765, 'MRR@10': 0.9786765}
Training metrics for epoch: 77  {'MAP': 0.9781977, 'NDCG@1': 0.75754434, 'NDCG@5': 0.9208438, 'NDCG@10': 0.92121166, 'MRR@1': 0.95784885, 'MRR@5': 0.9789244, 'MRR@10': 0.9789244}
Training metrics for epoch: 78  {'MAP': 0.9784483, 'NDCG@1': 0.7573208, 'NDCG@5': 0.9208061, 'NDCG@10': 0.92116976, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 79  {'MAP': 0.9786932, 'NDCG@1': 0.758861, 'NDCG@5': 0.92129385, 'NDCG@10': 0.9216534, 'MRR@1': 0.9588068, 'MRR@5': 0.97940344, 'MRR@10': 0.97940344}
Training metrics for epoch: 80  {'MAP': 0.97893256, 'NDCG@1': 0.7603666, 'NDCG@5': 0.9217707, 'NDCG@10': 0.9221262, 'MRR@1': 0.95926964, 'MRR@5': 0.9796348, 'MRR@10': 0.9796348}
Validation metrics for epoch: 80  {'MAP': 0.9791667, 'NDCG@1': 0.75925934, 'NDCG@5': 0.92172426, 'NDCG@10': 0.92167276, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.97986114}
Training metrics for epoch: 81  {'MAP': 0.9793956, 'NDCG@1': 0.7595501, 'NDCG@5': 0.9217872, 'NDCG@10': 0.92173624, 'MRR@1': 0.9587912, 'MRR@5': 0.9793956, 'MRR@10': 0.9800824}
Training metrics for epoch: 82  {'MAP': 0.97961956, 'NDCG@1': 0.7606108, 'NDCG@5': 0.92218626, 'NDCG@10': 0.92213583, 'MRR@1': 0.9592391, 'MRR@5': 0.97961956, 'MRR@10': 0.98029894}
Training metrics for epoch: 83  {'MAP': 0.9798387, 'NDCG@1': 0.7601127, 'NDCG@5': 0.9220197, 'NDCG@10': 0.92196983, 'MRR@1': 0.9596774, 'MRR@5': 0.9798387, 'MRR@10': 0.9805108}
Training metrics for epoch: 84  {'MAP': 0.9800532, 'NDCG@1': 0.7596252, 'NDCG@5': 0.92185676, 'NDCG@10': 0.9218074, 'MRR@1': 0.9601064, 'MRR@5': 0.9800532, 'MRR@10': 0.9807181}
Training metrics for epoch: 85  {'MAP': 0.9802632, 'NDCG@1': 0.7612783, 'NDCG@5': 0.9224118, 'NDCG@10': 0.922363, 'MRR@1': 0.9605263, 'MRR@5': 0.9802632, 'MRR@10': 0.98092103}
Training metrics for epoch: 86  {'MAP': 0.98046875, 'NDCG@1': 0.7626489, 'NDCG@5': 0.9228422, 'NDCG@10': 0.92279387, 'MRR@1': 0.9609375, 'MRR@5': 0.98046875, 'MRR@10': 0.9811198}
Training metrics for epoch: 87  {'MAP': 0.9806701, 'NDCG@1': 0.7639913, 'NDCG@5': 0.9232637, 'NDCG@10': 0.92321587, 'MRR@1': 0.9613402, 'MRR@5': 0.9806701, 'MRR@10': 0.9813144}
Training metrics for epoch: 88  {'MAP': 0.9808673, 'NDCG@1': 0.76372707, 'NDCG@5': 0.92320555, 'NDCG@10': 0.9231582, 'MRR@1': 0.9617347, 'MRR@5': 0.9808673, 'MRR@10': 0.9815051}
Training metrics for epoch: 89  {'MAP': 0.9810606, 'NDCG@1': 0.76611364, 'NDCG@5': 0.92398125, 'NDCG@10': 0.9239344, 'MRR@1': 0.9621212, 'MRR@5': 0.9810606, 'MRR@10': 0.9816919}
Training metrics for epoch: 90  {'MAP': 0.980625, 'NDCG@1': 0.7645835, 'NDCG@5': 0.9234557, 'NDCG@10': 0.92340934, 'MRR@1': 0.96125, 'MRR@5': 0.980625, 'MRR@10': 0.98125}
Validation metrics for epoch: 90  {'MAP': 0.98081684, 'NDCG@1': 0.7650285, 'NDCG@5': 0.923649, 'NDCG@10': 0.9235569, 'MRR@1': 0.9616337, 'MRR@5': 0.98081684, 'MRR@10': 0.98143566}
Training metrics for epoch: 91  {'MAP': 0.9810049, 'NDCG@1': 0.76581484, 'NDCG@5': 0.92394495, 'NDCG@10': 0.92385375, 'MRR@1': 0.9620098, 'MRR@5': 0.9810049, 'MRR@10': 0.9816176}
Training metrics for epoch: 92  {'MAP': 0.9811893, 'NDCG@1': 0.7662392, 'NDCG@5': 0.9240845, 'NDCG@10': 0.9239942, 'MRR@1': 0.9623786, 'MRR@5': 0.9811893, 'MRR@10': 0.98179615}
Training metrics for epoch: 93  {'MAP': 0.9813702, 'NDCG@1': 0.7678001, 'NDCG@5': 0.9246149, 'NDCG@10': 0.9245255, 'MRR@1': 0.96274036, 'MRR@5': 0.9813702, 'MRR@10': 0.98197114}
Training metrics for epoch: 94  {'MAP': 0.9815476, 'NDCG@1': 0.7676306, 'NDCG@5': 0.92459214, 'NDCG@10': 0.92450356, 'MRR@1': 0.96309525, 'MRR@5': 0.9815476, 'MRR@10': 0.98214287}
Training metrics for epoch: 95  {'MAP': 0.9817217, 'NDCG@1': 0.7698228, 'NDCG@5': 0.9253036, 'NDCG@10': 0.92521584, 'MRR@1': 0.9634434, 'MRR@5': 0.9817217, 'MRR@10': 0.9823113}
Training metrics for epoch: 96  {'MAP': 0.9807243, 'NDCG@1': 0.7683023, 'NDCG@5': 0.9247515, 'NDCG@10': 0.92466456, 'MRR@1': 0.9614486, 'MRR@5': 0.9807243, 'MRR@10': 0.9813084}
Training metrics for epoch: 97  {'MAP': 0.9797454, 'NDCG@1': 0.76714087, 'NDCG@5': 0.92425805, 'NDCG@10': 0.9241719, 'MRR@1': 0.9594907, 'MRR@5': 0.9797454, 'MRR@10': 0.9803241}
Training metrics for epoch: 98  {'MAP': 0.9799312, 'NDCG@1': 0.7682942, 'NDCG@5': 0.9246202, 'NDCG@10': 0.92453486, 'MRR@1': 0.9598624, 'MRR@5': 0.9799312, 'MRR@10': 0.9805046}
Training metrics for epoch: 99  {'MAP': 0.9801136, 'NDCG@1': 0.767695, 'NDCG@5': 0.9244149, 'NDCG@10': 0.92433035, 'MRR@1': 0.96022725, 'MRR@5': 0.9801136, 'MRR@10': 0.98068184}
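The logged values behave like running averages rather than per-step numbers. The metric objects returned by task.build_metrics() are stateful Keras metrics that are never reset in the training loop. The same objects are also passed to both train_step and validation_step, so each "Validation metrics" line includes every earlier training batch as well. For example, the first three logged MRR@1 values, 0.875, 0.9375 and 0.9166667, equal 7/8, 15/16 and 22/24.

The pure-Python RunningMean class below is a hypothetical stand-in that mimics this update_state/result/reset_state behavior. Its per-query values are chosen to reproduce those three numbers.

```python
class RunningMean:
    """Hypothetical stand-in for a stateful mean metric: averages all values since the last reset."""

    def __init__(self):
        self.reset_state()

    def reset_state(self):
        # Forget everything accumulated so far.
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Accumulate one batch of per-query metric values.
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        # Mean over every value seen since the last reset.
        return self.total / self.count if self.count else 0.0


mrr_at_1 = RunningMean()
history = []
# Per-query MRR@1 values (1 = a relevant answer ranked first), 8 queries per batch.
for batch in ([1, 1, 1, 1, 1, 1, 1, 0],  # training step 0
              [1] * 8,                   # validation batch, same metric object
              [1, 1, 1, 1, 1, 1, 0, 1]): # training step 1
    mrr_at_1.update_state(batch)
    history.append(round(mrr_at_1.result(), 7))
print(history)  # [0.875, 0.9375, 0.9166667]
```

To report clean validation numbers, one option (untested here) is to build a separate metric list for validation with a second task.build_metrics() call. Then call reset_state() on each of those metrics before every evaluation.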