使用 TensorFlow Lite 模型制作器进行文本搜索

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本 查看 TF Hub 模型

在这个 Colab 笔记本中,您可以学习如何使用 TensorFlow Lite 模型制作器 库来创建 TFLite 搜索器模型。您可以使用文本搜索器模型为您的应用程序构建语义搜索或智能回复。这种类型的模型允许您输入文本查询,并在文本数据集中搜索最相关的条目,例如网页数据库。该模型返回数据集中得分最小的距离条目列表,包括您指定的元数据,例如 URL、页面标题或其他文本条目标识符。构建完成后,您可以使用 任务库搜索器 API 将其部署到设备(例如 Android)上,只需几行代码即可运行推理。

本教程利用 CNN/DailyMail 数据集作为示例来创建 TFLite 搜索器模型。您可以尝试使用您自己的数据集,该数据集应具有兼容的输入逗号分隔值 (CSV) 格式。

使用可扩展最近邻进行文本搜索

本教程使用公开可用的 CNN/DailyMail 非匿名摘要数据集,该数据集来自 GitHub 仓库。该数据集包含超过 30 万篇新闻文章,使其成为构建搜索器模型的良好数据集,并在模型推理期间为文本查询返回各种相关新闻。

本示例中的文本搜索器模型使用 ScaNN(可扩展最近邻)索引文件,该文件可以从预定义的数据库中搜索类似的项目。ScaNN 在大规模高效向量相似性搜索方面实现了最先进的性能。

本 Colab 中使用此数据集中的亮点和 URL 来创建模型

  1. 亮点是用于生成嵌入特征向量的文本,然后用于搜索。
  2. URL 是在搜索相关亮点后显示给用户的返回结果。

本教程将这些数据保存到 CSV 文件中,然后使用 CSV 文件构建模型。以下是数据集中的几个示例。

亮点 URL
夏威夷航空公司再次在准点率方面排名第一。航空公司质量排名报告考察了美国 14 家最大的航空公司。ExpressJet
和美国航空公司的准点率最差。维珍美国航空公司的行李处理最好;西南航空公司的投诉率最低。
http://www.cnn.com/2013/04/08/travel/airline-quality-report
欧洲足球管理机构公布了竞标举办 2020 年决赛的国家名单。第 60 届决赛将由 13 个国家举办。
32 个国家正在考虑竞标举办 2020 年比赛。欧足联将于 9 月 25 日宣布举办城市。
http://edition.cnn.com:80/2013/09/20/sport/football/football-euro-2020-bid-countries/index.html?
曾经的章鱼猎手迪伦·梅耶现在也签署了一份由 5000 名潜水员签署的请愿书,禁止他们在西克雷斯特公园狩猎章鱼。华盛顿州
鱼类和野生动物部门的决定可能需要几个月时间。
http://www.dailymail.co.uk:80/news/article-2238423/Dylan-Mayer-Washington-considers-ban-Octopus-hunting-diver-caught-ate-Puget-Sound.html?
在宇宙大爆炸后 4.2 亿年,人们观测到了一个星系。这个星系是由美国宇航局的哈勃太空望远镜、斯皮策太空望远镜以及太空中的一个天然“变焦镜头”发现的。
http://www.dailymail.co.uk/sciencetech/article-2233883/The-furthest-object-seen-Record-breaking-image-shows-galaxy-13-3-BILLION-light-years-Earth.html

设置

首先安装所需的软件包,包括来自 GitHub 仓库 的 Model Maker 软件包。

sudo apt -y install libportaudio2
pip install -q tflite-model-maker
pip install gdown

导入所需的软件包。

from tflite_model_maker import searcher

准备数据集

本教程使用来自 GitHub 仓库 的 CNN/Daily Mail 摘要数据集。

首先,下载 cnn 和 dailymail 的文本和 url 并解压缩它们。如果无法从 Google Drive 下载,请等待几分钟再试一次,或者手动下载并将其上传到 Colab。

gdown https://drive.google.com/uc?id=0BwmD_VLjROrfTHk4NFg2SndKcjQ
gdown https://drive.google.com/uc?id=0BwmD_VLjROrfM1BxdkxVaTY2bWs

wget -O all_train.txt https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt
tar xzf cnn_stories.tgz
tar xzf dailymail_stories.tgz

然后,将数据保存到 CSV 文件中,该文件可以加载到 tflite_model_maker 库中。代码基于在 tensorflow_datasets 中加载此数据的逻辑。我们无法直接使用 tensorflow_dataset,因为它不包含本 Colab 中使用的 url。

由于将数据处理成整个数据集的嵌入特征向量需要很长时间,因此默认情况下,出于演示目的,只选择了 CNN 和 Daily Mail 数据集的前 5% 的故事。您可以调整比例,或者尝试使用预构建的 TFLite 模型(包含 CNN 和 Daily Mail 数据集的 50% 的故事)进行搜索。

将亮点和 url 保存到 CSV 文件

构建文本搜索器模型

通过加载数据集、使用数据创建模型并导出 TFLite 模型来创建文本搜索器模型。

步骤 1. 加载数据集

Model Maker 使用 CSV 格式的文本数据集和每个文本字符串的相应元数据(例如本例中的 url)来嵌入文本字符串到特征向量中,使用用户指定的嵌入器模型。

在本演示中,我们使用 通用句子编码器 构建搜索器模型,这是一个最先进的句子嵌入模型,它已经从 Colab 中重新训练。该模型针对设备上推理性能进行了优化,仅需 6 毫秒即可嵌入查询字符串(在 Pixel 6 上测量)。或者,您可以使用 量化版本,它更小,但每次嵌入需要 38 毫秒。

wget -O universal_sentence_encoder.tflite https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/searcher/text_to_image_blogpost/text_embedder.tflite

创建一个 searcher.TextDataLoader 实例,并使用 data_loader.load_from_csv 方法加载数据集。此步骤大约需要 10 分钟,因为它会逐个生成每个文本的嵌入特征向量。您也可以尝试上传自己的 CSV 文件并加载它来构建自定义模型。

指定 CSV 文件中文本列和元数据列的名称。

  • 文本用于生成嵌入特征向量。
  • 元数据是在搜索特定文本时要显示的内容。

以下是上面生成的 CNN-DailyMail CSV 文件的前 4 行。

highlights urls
叙利亚官员:奥巴马爬上了树顶,不知道怎么下来。奥巴马给众议院和参议院议长写了一封信。
奥巴马将寻求国会批准对叙利亚采取军事行动。联合国发言人表示,目标是确定是否使用了化学武器,而不是由谁使用。
http://www.cnn.com/2013/08/31/world/meast/syria-civil-war/
博尔特赢得世锦赛第三枚金牌。带领牙买加队获得 4x100 米接力赛冠军。博尔特在世锦赛上获得第八枚金牌。牙买加队
在女子 4x100 米接力赛中获得双料冠军。
http://edition.cnn.com/2013/08/18/sport/athletics-bolt-jamaica-gold
该机构堪萨斯城办事处的员工是数百名“虚拟”员工中的一员。该员工去年往返美国大陆的旅行费用超过
24,000 美元。远程办公计划与所有 GSA 做法一样,正在接受审查。
http://www.cnn.com:80/2012/08/23/politics/gsa-hawaii-teleworking
最新消息:一位加拿大医生表示,她曾是 2010 年检查哈里·伯克哈特的一支团队的成员。最新消息:诊断结果为“自闭症、严重焦虑、创伤后应激障碍
和抑郁症”。官员表示,伯克哈特也涉嫌一起德国纵火案调查。检察官认为,这位德国公民在洛杉矶纵火了一系列火灾。
http://edition.cnn.com:80/2012/01/05/justice/california-arson/index.html?
data_loader = searcher.TextDataLoader.create("universal_sentence_encoder.tflite", l2_normalize=True)
data_loader.load_from_csv("cnn_dailymail.csv", text_column="highlights", metadata_column="urls")

对于图像用例,您可以创建一个 searcher.ImageDataLoader 实例,然后使用 data_loader.load_from_folder 从文件夹中加载图像。该 searcher.ImageDataLoader 实例需要由 TFLite 嵌入器模型创建,因为它将用于将查询编码为特征向量,并与 TFLite 搜索器模型一起导出。例如

data_loader = searcher.ImageDataLoader.create("mobilenet_v2_035_96_embedder_with_metadata.tflite")
data_loader.load_from_folder("food/")

步骤 2. 创建搜索器模型

  • 配置 ScaNN 选项。有关更多详细信息,请参阅 API 文档
  • 从数据和 ScaNN 选项创建搜索器模型。您可以查看 深入分析 以了解有关 ScaNN 算法的更多信息。
scann_options = searcher.ScaNNOptions(
      distance_measure="dot_product",
      tree=searcher.Tree(num_leaves=140, num_leaves_to_search=4),
      score_ah=searcher.ScoreAH(dimensions_per_block=1, anisotropic_quantization_threshold=0.2))
model = searcher.Searcher.create_from_data(data_loader, scann_options)

在上面的示例中,我们定义了以下选项

  • distance_measure:我们使用“点积”来衡量两个嵌入向量之间的距离。请注意,我们实际上计算的是点积值,以保留“越小越近”的概念。

  • tree:数据集被划分为 140 个分区(大约是数据大小的平方根),在检索过程中搜索其中的 4 个分区,这大约是数据集的 3%。

  • score_ah:我们将浮点嵌入量化为具有相同维度的 int8 值,以节省空间。

步骤 3. 导出 TFLite 模型

然后,您可以导出 TFLite 搜索器模型。

model.export(
      export_filename="searcher.tflite",
      userinfo="",
      export_format=searcher.ExportFormat.TFLITE)

在您的查询上测试 TFLite 模型

您可以使用自定义查询文本测试导出的 TFLite 模型。要使用搜索器模型查询文本,请初始化模型并使用文本短语运行搜索,如下所示

from tflite_support.task import text

# Initializes a TextSearcher object.
searcher = text.TextSearcher.create_from_file("searcher.tflite")

# Searches the input query.
results = searcher.search("The Airline Quality Rankings Report looks at the 14 largest U.S. airlines.")
print(results)

有关如何将模型集成到各种平台的更多信息,请参阅 任务库文档

阅读更多

有关更多信息,请参阅