本教程将向您展示如何使用 TensorFlow Lite 构建 Android 应用程序,以提供对以自然语言文本结构化的问题的答案。该 示例应用程序 使用 BERT 问答器 (BertQuestionAnswerer
) 自然语言 (NL) 任务库 中的 API 来启用问答机器学习模型。该应用程序专为物理 Android 设备设计,但也可以在设备模拟器上运行。
如果您正在更新现有项目,可以使用示例应用程序作为参考或模板。有关如何将问答添加到现有应用程序的说明,请参阅 更新和修改您的应用程序。
问答概述
问答是机器学习任务,用于回答以自然语言提出的问题。经过训练的问答模型接收文本段落和问题作为输入,并尝试根据其对段落中信息的解释来回答问题。
问答模型是在问答数据集上训练的,该数据集包含阅读理解数据集以及基于不同文本段落的问答对。
有关本教程中模型生成方式的更多信息,请参阅 使用 TensorFlow Lite 模型制作器进行 BERT 问答 教程。
模型和数据集
示例应用程序使用 Mobile BERT Q&A (mobilebert
) 模型,它是 BERT (来自 Transformer 的双向编码器表示) 的更轻量级且更快的版本。有关 mobilebert
的更多信息,请参阅 MobileBERT:一种针对资源受限设备的紧凑型任务无关 BERT 研究论文。
该 mobilebert
模型使用斯坦福问答数据集 (SQuAD) 数据集进行训练,该数据集是一个阅读理解数据集,包含来自维基百科的文章以及每篇文章的一组问答对。
设置和运行示例应用程序
要设置问答应用程序,请从 GitHub 下载示例应用程序,并使用 Android Studio 运行它。
系统要求
- Android Studio 版本 2021.1.1 (Bumblebee) 或更高版本。
- Android SDK 版本 31 或更高版本
- 具有最低操作系统版本 SDK 21 (Android 7.0 - 牛轧糖) 的 Android 设备,并启用了 开发者模式,或 Android 模拟器。
获取示例代码
创建示例代码的本地副本。您将使用此代码在 Android Studio 中创建项目并运行示例应用程序。
要克隆和设置示例代码
- 克隆 git 存储库
git clone https://github.com/tensorflow/examples.git
- 可选地,配置您的 git 实例以使用稀疏检出,这样您将只拥有问答示例应用程序的文件
cd examples git sparse-checkout init --cone git sparse-checkout set lite/examples/bert_qa/android
导入并运行项目
从下载的示例代码创建项目,构建项目,然后运行它。
要导入和构建示例代码项目
- 启动 Android Studio。
- 从 Android Studio 中,选择 **文件 > 新建 > 导入项目**。
- 导航到包含 build.gradle 文件的示例代码目录(
.../examples/lite/examples/bert_qa/android/build.gradle
)并选择该目录。 - 如果 Android Studio 请求 Gradle 同步,请选择确定。
- 确保您的 Android 设备已连接到您的计算机,并且已启用开发者模式。单击绿色的
运行
箭头。
如果您选择了正确的目录,Android Studio 将创建一个新项目并构建它。此过程可能需要几分钟,具体取决于您的计算机速度以及您是否已将 Android Studio 用于其他项目。构建完成后,Android Studio 会在 **构建输出** 状态面板中显示 BUILD SUCCESSFUL
消息。
要运行项目
- 从 Android Studio 中,通过选择 **运行 > 运行...** 运行项目。
- 选择一个已连接的 Android 设备(或模拟器)来测试应用程序。
使用应用程序
在 Android Studio 中运行项目后,应用程序会自动在已连接的设备或设备模拟器上打开。
要使用问答示例应用程序
- 从主题列表中选择一个主题。
- 选择一个建议的问题或在文本框中输入您自己的问题。
- 切换橙色箭头以运行模型。
应用程序尝试从文章文本中识别问题的答案。如果模型在文章中检测到答案,应用程序会突出显示与用户相关的文本范围。
您现在拥有一个功能完善的问答应用程序。使用以下部分更好地了解示例应用程序的工作原理以及如何在生产应用程序中实现问答功能
示例应用程序的工作原理
应用程序使用 BertQuestionAnswerer
API,该 API 位于 自然语言 (NL) 任务库 包中。MobileBERT 模型使用 TensorFlow Lite 模型制作器 进行训练。应用程序默认在 CPU 上运行,可以选择使用 GPU 或 NNAPI 代理进行硬件加速。
以下文件和目录包含此应用程序的关键代码
- BertQaHelper.kt - 初始化问答器并处理模型和代理选择。
- QaFragment.kt - 处理和格式化结果。
- MainActivity.kt - 提供应用程序的组织逻辑。
修改您的应用程序
以下部分说明了修改您自己的 Android 应用程序以运行示例应用程序中显示的模型的关键步骤。这些说明使用示例应用程序作为参考点。您自己的应用程序所需的具体更改可能与示例应用程序有所不同。
打开或创建一个 Android 项目
您需要在 Android Studio 中创建一个 Android 开发项目才能继续执行以下说明。按照以下说明打开现有项目或创建一个新项目。
要打开现有的 Android 开发项目
- 在 Android Studio 中,选择 *文件 > 打开* 并选择一个现有项目。
要创建一个基本的 Android 开发项目
- 按照 Android Studio 中的说明 创建一个基本项目。
有关使用 Android Studio 的更多信息,请参阅 Android Studio 文档。
添加项目依赖项
在您自己的应用程序中,添加特定的项目依赖项以运行 TensorFlow Lite 机器学习模型并访问实用程序函数。这些函数将数据(例如字符串)转换为可以由模型处理的张量数据格式。以下说明解释了如何将所需的项目和模块依赖项添加到您自己的 Android 应用程序项目中。
要添加模块依赖项
在使用 TensorFlow Lite 的模块中,更新模块的
build.gradle
文件以包含以下依赖项。在示例应用程序中,依赖项位于 app/build.gradle 中
dependencies { ... // Import tensorflow library implementation 'org.tensorflow:tensorflow-lite-task-text:0.3.0' // Import the GPU delegate plugin Library for GPU inference implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.0' implementation 'org.tensorflow:tensorflow-lite-gpu:2.9.0' }
该项目必须包含文本任务库(
tensorflow-lite-task-text
)。如果您想修改此应用程序以在图形处理单元 (GPU) 上运行,GPU 库(
tensorflow-lite-gpu-delegate-plugin
)提供了在 GPU 上运行应用程序的基础结构,而 Delegate(tensorflow-lite-gpu
)提供了兼容性列表。在 Android Studio 中,通过选择:**文件 > 使用 Gradle 文件同步项目** 来同步项目依赖项。
初始化 ML 模型
在您的 Android 应用程序中,您必须使用参数初始化 TensorFlow Lite 机器学习模型,然后才能使用该模型运行预测。
TensorFlow Lite 模型存储为 *.tflite
文件。模型文件包含预测逻辑,通常包括有关如何解释预测结果的 元数据。通常,模型文件存储在开发项目的 src/main/assets
目录中,如代码示例所示
<project>/src/main/assets/mobilebert_qa.tflite
为了方便和代码可读性,示例声明了一个伴随对象,该对象定义了模型的设置。
要在您的应用程序中初始化模型
创建一个伴随对象以定义模型的设置。在示例应用程序中,此对象位于 BertQaHelper.kt 中
companion object { private const val BERT_QA_MODEL = "mobilebert.tflite" private const val TAG = "BertQaHelper" const val DELEGATE_CPU = 0 const val DELEGATE_GPU = 1 const val DELEGATE_NNAPI = 2 }
通过构建
BertQaHelper
对象创建模型的设置,并使用bertQuestionAnswerer
构建 TensorFlow Lite 对象。在示例应用程序中,这位于 BertQaHelper.kt 中的
setupBertQuestionAnswerer()
函数中class BertQaHelper( ... ) { ... init { setupBertQuestionAnswerer() } fun clearBertQuestionAnswerer() { bertQuestionAnswerer = null } private fun setupBertQuestionAnswerer() { val baseOptionsBuilder = BaseOptions.builder().setNumThreads(numThreads) ... val options = BertQuestionAnswererOptions.builder() .setBaseOptions(baseOptionsBuilder.build()) .build() try { bertQuestionAnswerer = BertQuestionAnswerer.createFromFileAndOptions(context, BERT_QA_MODEL, options) } catch (e: IllegalStateException) { answererListener ?.onError("Bert Question Answerer failed to initialize. See error logs for details") Log.e(TAG, "TFLite failed to load model with error: " + e.message) } } ... }
启用硬件加速(可选)
在您的应用程序中初始化 TensorFlow Lite 模型时,您应该考虑使用硬件加速功能来加快模型的预测计算。TensorFlow Lite 代理 是软件模块,它们使用移动设备上的专用处理硬件(例如图形处理单元 (GPU) 或张量处理单元 (TPU))来加速机器学习模型的执行。
要在您的应用程序中启用硬件加速
创建一个变量以定义应用程序将使用的代理。在示例应用程序中,此变量位于 BertQaHelper.kt 的开头
var currentDelegate: Int = 0
创建一个代理选择器。在示例应用程序中,代理选择器位于 BertQaHelper.kt 中的
setupBertQuestionAnswerer
函数中when (currentDelegate) { DELEGATE_CPU -> { // Default } DELEGATE_GPU -> { if (CompatibilityList().isDelegateSupportedOnThisDevice) { baseOptionsBuilder.useGpu() } else { answererListener?.onError("GPU is not supported on this device") } } DELEGATE_NNAPI -> { baseOptionsBuilder.useNnapi() } }
建议使用代理运行 TensorFlow Lite 模型,但不是必需的。有关使用 TensorFlow Lite 代理的更多信息,请参阅 TensorFlow Lite 代理。
准备模型数据
在您的 Android 应用程序中,您的代码通过将现有数据(例如原始文本)转换为可以由模型处理的 张量 数据格式来为模型提供数据。传递给模型的张量必须具有与用于训练模型的数据格式匹配的特定维度或形状。此问答应用程序接受 字符串 作为文本文章和问题的输入。模型无法识别特殊字符和非英语单词。
要为模型提供文章文本数据
使用
LoadDataSetClient
对象将文章文本数据加载到应用程序中。在示例应用程序中,这位于 LoadDataSetClient.kt 中fun loadJson(): DataSet? { var dataSet: DataSet? = null try { val inputStream: InputStream = context.assets.open(JSON_DIR) val bufferReader = inputStream.bufferedReader() val stringJson: String = bufferReader.use { it.readText() } val datasetType = object : TypeToken<DataSet>() {}.type dataSet = Gson().fromJson(stringJson, datasetType) } catch (e: IOException) { Log.e(TAG, e.message.toString()) } return dataSet }
使用
DatasetFragment
对象列出每篇文章的标题并启动 **TFL 问答** 屏幕。在示例应用程序中,这位于 DatasetFragment.kt 中class DatasetFragment : Fragment() { ... override fun onViewCreated(view: View, savedInstanceState: Bundle?) { super.onViewCreated(view, savedInstanceState) val client = LoadDataSetClient(requireActivity()) client.loadJson()?.let { titles = it.getTitles() } ... } ... }
使用
DatasetAdapter
对象中的onCreateViewHolder
函数来呈现每篇文章的标题。在示例应用程序中,这位于 DatasetAdapter.kt 中override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ViewHolder { val binding = ItemDatasetBinding.inflate( LayoutInflater.from(parent.context), parent, false ) return ViewHolder(binding) }
要为模型提供用户问题
使用
QaAdapter
对象将问题提供给模型。在示例应用程序中,这位于 QaAdapter.kt 中class QaAdapter(private val question: List<String>, private val select: (Int) -> Unit) : RecyclerView.Adapter<QaAdapter.ViewHolder>() { inner class ViewHolder(private val binding: ItemQuestionBinding) : RecyclerView.ViewHolder(binding.root) { init { binding.tvQuestionSuggestion.setOnClickListener { select.invoke(adapterPosition) } } fun bind(question: String) { binding.tvQuestionSuggestion.text = question } } ... }
运行预测
在您的 Android 应用程序中,一旦您初始化了 BertQuestionAnswerer 对象,您就可以开始以自然语言文本的形式将问题输入模型。模型尝试在文章中识别答案。
要运行预测
创建一个
answer
函数,该函数运行模型并测量识别答案所需的时间(inferenceTime
)。在示例应用程序中,answer
函数位于 BertQaHelper.kt 中fun answer(contextOfQuestion: String, question: String) { if (bertQuestionAnswerer == null) { setupBertQuestionAnswerer() } var inferenceTime = SystemClock.uptimeMillis() val answers = bertQuestionAnswerer?.answer(contextOfQuestion, question) inferenceTime = SystemClock.uptimeMillis() - inferenceTime answererListener?.onResults(answers, inferenceTime) }
将来自
answer
的结果传递给侦听器对象。interface AnswererListener { fun onError(error: String) fun onResults( results: List<QaAnswer>?, inferenceTime: Long ) }
处理模型输出
输入问题后,模型将在文章中提供最多五个可能的答案。
要获取模型的结果
为侦听器对象创建一个
onResult
函数来处理输出。在示例应用程序中,侦听器对象位于 BertQaHelper.kt 中interface AnswererListener { fun onError(error: String) fun onResults( results: List<QaAnswer>?, inferenceTime: Long ) }
根据结果突出显示段落中的部分。在示例应用程序中,这位于 QaFragment.kt 中。
override fun onResults(results: List<QaAnswer>?, inferenceTime: Long) { results?.first()?.let { highlightAnswer(it.text) } fragmentQaBinding.tvInferenceTime.text = String.format( requireActivity().getString(R.string.bottom_view_inference_time), inferenceTime ) }
模型返回一组结果后,您的应用程序可以通过向用户展示结果或执行其他逻辑来处理这些预测。
下一步
- 使用 TensorFlow Lite 模型制作器中的问答 教程从头开始训练和实现模型。
- 探索更多 TensorFlow 的文本处理工具。
- 在 TensorFlow Hub 上下载其他 BERT 模型。
- 在 示例 中探索 TensorFlow Lite 的各种用途。