使用 Android 进行文本分类

本教程将向您展示如何使用 TensorFlow Lite 构建一个 Android 应用程序,以对自然语言文本进行分类。此应用程序专为物理 Android 设备设计,但也可以在设备模拟器上运行。

示例应用程序 使用 TensorFlow Lite 将文本分类为正面或负面,使用 自然语言 (NL) 任务库 来启用文本分类机器学习模型的执行。

如果您要更新现有项目,可以使用示例应用程序作为参考或模板。有关如何将文本分类添加到现有应用程序的说明,请参阅 更新和修改您的应用程序

文本分类概述

文本分类是机器学习任务,用于将一组预定义的类别分配给开放式文本。文本分类模型是在自然语言文本语料库上训练的,其中单词或短语被手动分类。

训练后的模型接收文本作为输入,并尝试根据其训练分类的已知类别集对文本进行分类。例如,本示例中的模型接受一段文本,并确定文本的情感是正面还是负面。对于每段文本,文本分类模型都会输出一个分数,该分数表示文本被正确分类为正面或负面的置信度。

有关本教程中模型如何生成的更多信息,请参阅 使用 TensorFlow Lite Model Maker 进行文本分类 教程。

模型和数据集

本教程使用使用 SST-2(斯坦福情感树库)数据集训练的模型。SST-2 包含 67,349 条电影评论用于训练,872 条电影评论用于测试,每条评论都分类为正面或负面。此应用程序中使用的模型使用 TensorFlow Lite Model Maker 工具进行训练。

示例应用程序使用以下预训练模型

  • 平均词向量 (NLClassifier) - 任务库的 NLClassifier 将输入文本分类为不同的类别,并且可以处理大多数文本分类模型。

  • MobileBERT (BertNLClassifier) - 任务库的 BertNLClassifier 与 NLClassifier 类似,但专门针对需要图外 Wordpiece 和 Sentencepiece 分词的情况。

设置和运行示例应用程序

要设置文本分类应用程序,请从 GitHub 下载示例应用程序,并使用 Android Studio 运行它。

系统要求

  • Android Studio 版本 2021.1.1(Bumblebee)或更高版本。
  • Android SDK 版本 31 或更高版本
  • 需要运行此示例应用的 Android 设备,其操作系统版本至少为 SDK 21(Android 7.0 - 牛轧糖),并已启用 开发者模式,或 Android 模拟器。

获取示例代码

创建示例代码的本地副本。您将使用此代码在 Android Studio 中创建项目并运行示例应用程序。

要克隆和设置示例代码

  1. 克隆 git 存储库
    git clone https://github.com/tensorflow/examples.git
    
  2. 可选地,配置您的 git 实例以使用稀疏签出,这样您将只拥有文本分类示例应用程序的文件
    cd examples
    git sparse-checkout init --cone
    git sparse-checkout set lite/examples/text_classification/android
    

导入并运行项目

从下载的示例代码创建项目,构建项目,然后运行它。

要导入和构建示例代码项目

  1. 启动 Android Studio
  2. 在 Android Studio 中,选择 **文件 > 新建 > 导入项目**。
  3. 导航到包含 build.gradle 文件的示例代码目录 (.../examples/lite/examples/text_classification/android/build.gradle) 并选择该目录。
  4. 如果 Android Studio 请求 Gradle 同步,请选择确定。
  5. 确保您的 Android 设备已连接到您的计算机,并且已启用开发者模式。单击绿色的 运行 箭头。

如果您选择了正确的目录,Android Studio 将创建一个新项目并构建它。此过程可能需要几分钟,具体取决于计算机的速度以及您是否已将 Android Studio 用于其他项目。构建完成后,Android Studio 将在 **构建输出** 状态面板中显示 构建成功 消息。

要运行项目

  1. 在 Android Studio 中,通过选择 **运行 > 运行…** 来运行项目。
  2. 选择连接的 Android 设备(或模拟器)以测试应用程序。

使用应用程序

Text classification example app in Android

在 Android Studio 中运行项目后,应用程序将自动在连接的设备或设备模拟器上打开。

要使用文本分类器

  1. 在文本框中输入一段文本。
  2. 从 **委托** 下拉菜单中,选择 CPUNNAPI
  3. 通过选择 AverageWordVecMobileBERT 来指定模型。
  4. 选择 **分类**。

应用程序输出一个 *正* 分数和一个 *负* 分数。这两个分数将加起来为 1,并衡量输入文本的情感是正面还是负面的可能性。数字越大,表示置信度越高。

您现在拥有一个功能完备的文本分类应用程序。使用以下部分更好地了解示例应用程序的工作原理,以及如何在生产应用程序中实现文本分类功能

示例应用程序的工作原理

应用程序使用 自然语言 (NL) 任务库 包来实现文本分类模型。这两个模型,平均词向量和 MobileBERT,是使用 TensorFlow Lite 模型制作器 训练的。应用程序默认在 CPU 上运行,可以选择使用 NNAPI 委托进行硬件加速。

以下文件和目录包含此文本分类应用程序的关键代码

修改您的应用程序

以下部分说明了修改您自己的 Android 应用程序以运行示例应用程序中显示的模型的关键步骤。这些说明使用示例应用程序作为参考点。您自己的应用程序所需的具体更改可能与示例应用程序不同。

打开或创建 Android 项目

您需要在 Android Studio 中有一个 Android 开发项目才能继续执行以下说明。按照以下说明打开现有项目或创建一个新项目。

要打开现有的 Android 开发项目

  • 在 Android Studio 中,选择 *文件 > 打开* 并选择现有项目。

要创建一个基本的 Android 开发项目

有关使用 Android Studio 的更多信息,请参阅 Android Studio 文档

添加项目依赖项

在您自己的应用程序中,您必须添加特定的项目依赖项才能运行 TensorFlow Lite 机器学习模型,并访问将数据(例如字符串)转换为张量数据格式的实用程序函数,该格式可以由您使用的模型处理。

以下说明解释了如何在您自己的 Android 应用程序项目中添加所需的项目和模块依赖项。

要添加模块依赖项

  1. 在使用 TensorFlow Lite 的模块中,更新模块的 build.gradle 文件以包含以下依赖项。

    在示例应用程序中,依赖项位于 app/build.gradle

    dependencies {
      ...
      implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.0'
    }
    

    该项目必须包含文本任务库 (tensorflow-lite-task-text)。

    如果您想修改此应用程序以在图形处理单元 (GPU) 上运行,GPU 库 (tensorflow-lite-gpu-delegate-plugin) 提供了在 GPU 上运行应用程序的基础设施,而 Delegate (tensorflow-lite-gpu) 提供了兼容性列表。在本教程的范围之外,不会在 GPU 上运行此应用程序。

  2. 在 Android Studio 中,通过选择:**文件 > 与 Gradle 文件同步项目** 来同步项目依赖项。

初始化 ML 模型

在您的 Android 应用程序中,您必须使用参数初始化 TensorFlow Lite 机器学习模型,然后才能使用该模型运行预测。

TensorFlow Lite 模型存储为 *.tflite 文件。模型文件包含预测逻辑,通常包括有关如何解释预测结果的 元数据,例如预测类别名称。通常,模型文件存储在开发项目的 src/main/assets 目录中,如代码示例所示

  • <project>/src/main/assets/mobilebert.tflite
  • <project>/src/main/assets/wordvec.tflite

为了方便起见和代码可读性,示例声明了一个伴随对象,该对象定义了模型的设置。

要在您的应用程序中初始化模型

  1. 创建一个伴随对象来定义模型的设置。在示例应用程序中,此对象位于 TextClassificationHelper.kt

    companion object {
      const val DELEGATE_CPU = 0
      const val DELEGATE_NNAPI = 1
      const val WORD_VEC = "wordvec.tflite"
      const val MOBILEBERT = "mobilebert.tflite"
    }
    
  2. 通过构建分类器对象来创建模型的设置,并使用 BertNLClassifierNLClassifier 构造 TensorFlow Lite 对象。

    在示例应用程序中,这位于 TextClassificationHelper.kt 中的 initClassifier 函数中。

    fun initClassifier() {
      ...
      if( currentModel == MOBILEBERT ) {
        ...
        bertClassifier = BertNLClassifier.createFromFileAndOptions(
          context,
          MOBILEBERT,
          options)
      } else if (currentModel == WORD_VEC) {
          ...
          nlClassifier = NLClassifier.createFromFileAndOptions(
            context,
            WORD_VEC,
            options)
      }
    }
    

启用硬件加速(可选)

在您的应用程序中初始化 TensorFlow Lite 模型时,您应该考虑使用硬件加速功能来加快模型的预测计算速度。TensorFlow Lite 委托 是软件模块,它们使用移动设备上的专用处理硬件(例如图形处理单元 (GPU) 或张量处理单元 (TPU))来加速机器学习模型的执行。

要在您的应用程序中启用硬件加速

  1. 创建一个变量来定义应用程序将使用的委托。在示例应用程序中,此变量位于 TextClassificationHelper.kt 的开头。

    var currentDelegate: Int = 0
    
  2. 创建一个委托选择器。在示例应用程序中,委托选择器位于 TextClassificationHelper.kt 中的 initClassifier 函数中。

    val baseOptionsBuilder = BaseOptions.builder()
    when (currentDelegate) {
       DELEGATE_CPU -> {
           // Default
       }
       DELEGATE_NNAPI -> {
           baseOptionsBuilder.useNnapi()
       }
    }
    

建议使用委托来运行 TensorFlow Lite 模型,但不是必需的。有关使用 TensorFlow Lite 委托的更多信息,请参阅 TensorFlow Lite 委托

准备模型数据

在您的 Android 应用程序中,您的代码通过将现有数据(例如原始文本)转换为可以由您的模型处理的 张量 数据格式来为模型提供数据,以便模型进行解释。传递给模型的张量中的数据必须具有与用于训练模型的数据格式匹配的特定维度或形状。

此文本分类应用程序接受 字符串 作为输入,并且模型仅使用英语语料库进行训练。在推理过程中,特殊字符和非英语单词将被忽略。

要为模型提供文本数据

  1. 确保 initClassifier 函数包含委托和模型的代码,如 初始化 ML 模型启用硬件加速 部分所述。

  2. 使用 init 块来调用 initClassifier 函数。在示例应用程序中,init 位于 TextClassificationHelper.kt

    init {
      initClassifier()
    }
    

运行预测

在您的 Android 应用程序中,一旦您初始化了 BertNLClassifierNLClassifier 对象,您就可以开始为模型提供输入文本,以便模型将其分类为“正面”或“负面”。

要运行预测

  1. 创建一个 classify 函数,该函数使用选定的分类器 (currentModel) 并测量对输入文本进行分类所花费的时间 (inferenceTime)。在示例应用程序中,classify 函数位于 TextClassificationHelper.kt

    fun classify(text: String) {
      executor = ScheduledThreadPoolExecutor(1)
    
      executor.execute {
        val results: List<Category>
        // inferenceTime is the amount of time, in milliseconds, that it takes to
        // classify the input text.
        var inferenceTime = SystemClock.uptimeMillis()
    
        // Use the appropriate classifier based on the selected model
        if(currentModel == MOBILEBERT) {
          results = bertClassifier.classify(text)
        } else {
          results = nlClassifier.classify(text)
        }
    
        inferenceTime = SystemClock.uptimeMillis() - inferenceTime
    
        listener.onResult(results, inferenceTime)
      }
    }
    
  2. 将来自 classify 的结果传递给侦听器对象。

    fun classify(text: String) {
      ...
      listener.onResult(results, inferenceTime)
    }
    

处理模型输出

输入一行文本后,模型会生成一个预测分数,以浮点数形式表示,介于 0 到 1 之间,分别代表“正面”和“负面”类别。

要获取模型的预测结果

  1. 为监听器对象创建一个 onResult 函数来处理输出。在示例应用程序中,监听器对象位于 MainActivity.kt

    private val listener = object : TextClassificationHelper.TextResultsListener {
      override fun onResult(results: List<Category>, inferenceTime: Long) {
        runOnUiThread {
          activityMainBinding.bottomSheetLayout.inferenceTimeVal.text =
            String.format("%d ms", inferenceTime)
    
          adapter.resultsList = results.sortedByDescending {
            it.score
          }
    
          adapter.notifyDataSetChanged()
        }
      }
      ...
    }
    
  2. 在监听器对象中添加一个 onError 函数来处理错误

      private val listener = object : TextClassificationHelper.TextResultsListener {
        ...
        override fun onError(error: String) {
          Toast.makeText(this@MainActivity, error, Toast.LENGTH_SHORT).show()
        }
      }
    

模型返回一组预测结果后,您的应用程序可以通过向用户展示结果或执行其他逻辑来处理这些预测。示例应用程序在用户界面中列出了预测分数。

下一步