通用句子编码器

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看 下载笔记本 查看 TF Hub 模型

此笔记本演示了如何访问通用句子编码器并将其用于句子相似性和句子分类任务。

通用句子编码器使获取句子级嵌入变得像以前查找单个单词的嵌入一样容易。然后,句子嵌入可以轻松地用于计算句子级含义相似性,以及使用较少的监督训练数据来提高下游分类任务的性能。

设置

本节设置了访问 TF Hub 上的通用句子编码器的环境,并提供了将编码器应用于单词、句子和段落的示例。

%%capture
!pip3 install seaborn

有关安装 Tensorflow 的更详细的信息,请访问 https://tensorflowcn.cn/install/

加载通用句子编码器的 TF Hub 模块

2024-03-10 12:03:32.159319: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
module https://tfhub.dev/google/universal-sentence-encoder/4 loaded

计算每个消息的表示,显示支持的各种长度。

Message: Elephant
Embedding size: 512
Embedding: [0.008344484493136406, 0.0004808559315279126, 0.06595249474048615, ...]

Message: I am a sentence for which I would like to get its embedding.
Embedding size: 512
Embedding: [0.050808604806661606, -0.016524329781532288, 0.01573779620230198, ...]

Message: Universal Sentence Encoder embeddings also support short paragraphs. There is no hard limit on how long the paragraph is. Roughly, the longer the more 'diluted' the embedding will be.
Embedding size: 512
Embedding: [-0.028332693502306938, -0.0558621808886528, -0.012941482476890087, ...]

语义文本相似性任务示例

通用句子编码器生成的嵌入是近似归一化的。两个句子的语义相似性可以轻松地计算为编码的内积。

def plot_similarity(labels, features, rotation):
  corr = np.inner(features, features)
  sns.set(font_scale=1.2)
  g = sns.heatmap(
      corr,
      xticklabels=labels,
      yticklabels=labels,
      vmin=0,
      vmax=1,
      cmap="YlOrRd")
  g.set_xticklabels(labels, rotation=rotation)
  g.set_title("Semantic Textual Similarity")

def run_and_plot(messages_):
  message_embeddings_ = embed(messages_)
  plot_similarity(messages_, message_embeddings_, 90)

可视化相似性

这里我们显示了热图中的相似性。最终图是一个 9x9 矩阵,其中每个条目 [i, j] 根据句子 ij 的编码的内积进行着色。

messages = [
    # Smartphones
    "I like my phone",
    "My phone is not good.",
    "Your cellphone looks great.",

    # Weather
    "Will it snow tomorrow?",
    "Recently a lot of hurricanes have hit the US",
    "Global warming is real",

    # Food and health
    "An apple a day, keeps the doctors away",
    "Eating strawberries is healthy",
    "Is paleo better than keto?",

    # Asking about age
    "How old are you?",
    "what is your age?",
]

run_and_plot(messages)

png

评估:STS(语义文本相似性)基准

STS 基准 提供了对使用句子嵌入计算的相似性分数与人类判断一致程度的内在评估。该基准要求系统为各种句子对返回相似性分数。然后使用 皮尔逊相关系数 来评估机器相似性分数与人类判断的质量。

下载数据

import pandas
import scipy
import math
import csv

sts_dataset = tf.keras.utils.get_file(
    fname="Stsbenchmark.tar.gz",
    origin="http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz",
    extract=True)
sts_dev = pandas.read_table(
    os.path.join(os.path.dirname(sts_dataset), "stsbenchmark", "sts-dev.csv"),
    skip_blank_lines=True,
    usecols=[4, 5, 6],
    names=["sim", "sent_1", "sent_2"])
sts_test = pandas.read_table(
    os.path.join(
        os.path.dirname(sts_dataset), "stsbenchmark", "sts-test.csv"),
    quoting=csv.QUOTE_NONE,
    skip_blank_lines=True,
    usecols=[4, 5, 6],
    names=["sim", "sent_1", "sent_2"])
# cleanup some NaN values in sts_dev
sts_dev = sts_dev[[isinstance(s, str) for s in sts_dev['sent_2']]]
Downloading data from http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz
409630/409630 ━━━━━━━━━━━━━━━━━━━━ 1s 2us/step

评估句子嵌入

sts_data = sts_dev

def run_sts_benchmark(batch):
  sts_encode1 = tf.nn.l2_normalize(embed(tf.constant(batch['sent_1'].tolist())), axis=1)
  sts_encode2 = tf.nn.l2_normalize(embed(tf.constant(batch['sent_2'].tolist())), axis=1)
  cosine_similarities = tf.reduce_sum(tf.multiply(sts_encode1, sts_encode2), axis=1)
  clip_cosine_similarities = tf.clip_by_value(cosine_similarities, -1.0, 1.0)
  scores = 1.0 - tf.acos(clip_cosine_similarities) / math.pi
  """Returns the similarity scores"""
  return scores

dev_scores = sts_data['sim'].tolist()
scores = []
for batch in np.array_split(sts_data, 10):
  scores.extend(run_sts_benchmark(batch))

pearson_correlation = scipy.stats.pearsonr(scores, dev_scores)
print('Pearson correlation coefficient = {0}\np-value = {1}'.format(
    pearson_correlation[0], pearson_correlation[1]))
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/numpy/core/fromnumeric.py:59: FutureWarning: 'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
  return bound(*args, **kwds)
Pearson correlation coefficient = 0.8036396940028219
p-value = 0.0