构建您自己的任务 API

TensorFlow Lite 任务库 在相同的基础设施之上提供预构建的本机/Android/iOS API,该基础设施抽象了 TensorFlow。如果您的模型不受现有任务库的支持,您可以扩展任务 API 基础设施来构建自定义 API。

概述

任务 API 基础设施具有两层结构:底层 C++ 层封装了本机 TFLite 运行时,顶层 Java/ObjC 层通过 JNI 或本机包装器与 C++ 层通信。

在 C++ 中实现所有 TensorFlow 逻辑可最大限度地降低成本,最大限度地提高推理性能并简化跨平台的整体工作流程。

要创建任务类,请扩展 BaseTaskApi 以提供 TFLite 模型接口和任务 API 接口之间的转换逻辑,然后使用 Java/ObjC 实用程序创建相应的 API。隐藏了所有 TensorFlow 细节,您可以在应用程序中部署 TFLite 模型,而无需任何机器学习知识。

TensorFlow Lite 为大多数流行的 视觉和 NLP 任务 提供了一些预构建的 API。您可以使用任务 API 基础设施为其他任务构建自己的 API。

prebuilt_task_apis
图 1. 预构建的任务 API

使用任务 API 基础设施构建您自己的 API

C++ API

所有 TFLite 细节都在本机 API 中实现。通过使用其中一个工厂函数创建 API 对象,并通过调用接口中定义的函数获取模型结果。

示例用法

以下是一个使用 C++ BertQuestionAnswerer 的示例,用于 MobileBert

  char kBertModelPath[] = "path/to/model.tflite";
 
// Create the API from a model file
  std
::unique_ptr<BertQuestionAnswerer> question_answerer =
     
BertQuestionAnswerer::CreateFromFile(kBertModelPath);

 
char kContext[] = ...; // context of a question to be answered
 
char kQuestion[] = ...; // question to be answered
 
// ask a question
  std
::vector<QaAnswer> answers = question_answerer.Answer(kContext, kQuestion);
 
// answers[0].text is the best answer

构建 API

native_task_api
图 2. 本机任务 API

要构建 API 对象,您必须通过扩展 BaseTaskApi 提供以下信息

  • 确定 API I/O - 您的 API 应在不同平台上公开类似的输入/输出。例如,BertQuestionAnswerer 接受两个字符串 (std::string& context, std::string& question) 作为输入,并输出可能的答案和概率的向量,作为 std::vector<QaAnswer>。这是通过在 BaseTaskApi模板参数 中指定相应的类型来完成的。指定了模板参数后,BaseTaskApi::Infer 函数将具有正确的输入/输出类型。此函数可以直接由 API 客户端调用,但最好将其包装在模型特定函数中,在本例中为 BertQuestionAnswerer::Answer

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std
    ::vector<QaAnswer>, // OutputType
                                 
    const std::string&, const std::string& // InputTypes
                                 
    > {
     
    // Model specific function delegating calls to BaseTaskApi::Infer
      std
    ::vector<QaAnswer> Answer(const std::string& context, const std::string& question) {
       
    return Infer(context, question).value();
     
    }
    }
  • 提供 API I/O 和模型的输入/输出张量之间的转换逻辑 - 指定了输入和输出类型后,子类还需要实现类型化函数 BaseTaskApi::PreprocessBaseTaskApi::Postprocess。这两个函数提供来自 TFLite FlatBuffer输入输出。子类负责将 API I/O 中的值分配给 I/O 张量。请参阅 BertQuestionAnswerer 中的完整实现示例。

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std
    ::vector<QaAnswer>, // OutputType
                                 
    const std::string&, const std::string& // InputTypes
                                 
    > {
     
    // Convert API input into tensors
      absl
    ::Status BertQuestionAnswerer::Preprocess(
       
    const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model
       
    const std::string& context, const std::string& query // InputType of the API
     
    ) {
       
    // Perform tokenization on input strings
       
    ...
       
    // Populate IDs, Masks and SegmentIDs to corresponding input tensors
       
    PopulateTensor(input_ids, input_tensors[0]);
       
    PopulateTensor(input_mask, input_tensors[1]);
       
    PopulateTensor(segment_ids, input_tensors[2]);
       
    return absl::OkStatus();
     
    }

     
    // Convert output tensors into API output
     
    StatusOr<std::vector<QaAnswer>> // OutputType
     
    BertQuestionAnswerer::Postprocess(
       
    const std::vector<const TfLiteTensor*>& output_tensors, // output tensors of the model
     
    ) {
       
    // Get start/end logits of prediction result from output tensors
        std
    ::vector<float> end_logits;
        std
    ::vector<float> start_logits;
       
    // output_tensors[0]: end_logits FLOAT[1, 384]
       
    PopulateVector(output_tensors[0], &end_logits);
       
    // output_tensors[1]: start_logits FLOAT[1, 384]
       
    PopulateVector(output_tensors[1], &start_logits);
       
    ...
        std
    ::vector<QaAnswer::Pos> orig_results;
       
    // Look up the indices from vocabulary file and build results
       
    ...
       
    return orig_results;
     
    }
    }
  • 创建 API 的工厂函数 - 初始化 tflite::Interpreter 需要一个模型文件和一个 OpResolverTaskAPIFactory 提供了创建 BaseTaskApi 实例的实用函数。

    您还必须提供与模型关联的任何文件。例如,BertQuestionAnswerer 还可以为其分词器的词汇表提供一个额外的文件。

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std
    ::vector<QaAnswer>, // OutputType
                                 
    const std::string&, const std::string& // InputTypes
                                 
    > {
     
    // Factory function to create the API instance
     
    StatusOr<std::unique_ptr<QuestionAnswerer>>
     
    BertQuestionAnswerer::CreateBertQuestionAnswerer(
         
    const std::string& path_to_model, // model to passed to TaskApiFactory
         
    const std::string& path_to_vocab  // additional model specific files
     
    ) {
       
    // Creates an API object by calling one of the utils from TaskAPIFactory
        std
    ::unique_ptr<BertQuestionAnswerer> api_to_init;
        ASSIGN_OR_RETURN
    (
            api_to_init
    ,
            core
    ::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
                path_to_model
    ,
                absl
    ::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
                kNumLiteThreads
    ));

       
    // Perform additional model specific initializations
       
    // In this case building a vocabulary vector from the vocab file.
        api_to_init
    ->InitializeVocab(path_to_vocab);
       
    return api_to_init;
     
    }
    }

Android API

通过定义 Java/Kotlin 接口并将逻辑通过 JNI 委托给 C++ 层来创建 Android API。Android API 要求首先构建本地 API。

示例用法

以下是一个使用 Java BertQuestionAnswerer 用于 MobileBert 的示例。

  String BERT_MODEL_FILE = "path/to/model.tflite";
 
String VOCAB_FILE = "path/to/vocab.txt";
 
// Create the API from a model file and vocabulary file
   
BertQuestionAnswerer bertQuestionAnswerer =
       
BertQuestionAnswerer.createBertQuestionAnswerer(
           
ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);

 
String CONTEXT = ...; // context of a question to be answered
 
String QUESTION = ...; // question to be answered
 
// ask a question
 
List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
 
// answers.get(0).text is the best answer

构建 API

android_task_api
图 3. Android 任务 API

与本地 API 类似,要构建 API 对象,客户端需要通过扩展 BaseTaskApi 来提供以下信息,该类为所有 Java 任务 API 提供了 JNI 处理。

  • 确定 API I/O - 这通常反映本地接口。例如,BertQuestionAnswerer(String context, String question) 作为输入,并输出 List<QaAnswer>。实现调用具有类似签名的私有本地函数,除了它有一个额外的参数 long nativeHandle,它是从 C++ 返回的指针。

    class BertQuestionAnswerer extends BaseTaskApi {
     
    public List<QaAnswer> answer(String context, String question) {
       
    return answerNative(getNativeHandle(), context, question);
     
    }

     
    private static native List<QaAnswer> answerNative(
                                           
    long nativeHandle, // C++ pointer
                                           
    String context, String question // API I/O
                                           
    );

    }
  • 创建 API 的工厂函数 - 这也反映了本地工厂函数,除了 Android 工厂函数还需要接受 Context 用于文件访问。实现调用 TaskJniUtils 中的某个实用程序来构建相应的 C++ API 对象,并将它的指针传递给 BaseTaskApi 构造函数。

      class BertQuestionAnswerer extends BaseTaskApi {
       
    private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME =
                                                 
    "bert_question_answerer_jni";

       
    // Extending super constructor by providing the
       
    // native handle(pointer of corresponding C++ API object)
       
    private BertQuestionAnswerer(long nativeHandle) {
         
    super(nativeHandle);
       
    }

       
    public static BertQuestionAnswerer createBertQuestionAnswerer(
                                           
    Context context, // Accessing Android files
                                           
    String pathToModel, String pathToVocab) {
         
    return new BertQuestionAnswerer(
             
    // The util first try loads the JNI module with name
             
    // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens two files,
             
    // converts them into ByteBuffer, finally ::initJniWithBertByteBuffers
             
    // is called with the buffer for a C++ API object pointer
             
    TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
                  context
    ,
                 
    BertQuestionAnswerer::initJniWithBertByteBuffers,
                  BERT_QUESTION_ANSWERER_NATIVE_LIBNAME
    ,
                  pathToModel
    ,
                  pathToVocab
    ));
       
    }

       
    // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
       
    // returns C++ API object pointer casted to long
       
    private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);

     
    }
  • 实现本地函数的 JNI 模块 - 所有 Java 本地方法都是通过调用 JNI 模块中的相应本地函数来实现的。工厂函数将创建一个本地 API 对象,并将它的指针作为 long 类型返回给 Java。在以后对 Java API 的调用中,long 类型指针将被传递回 JNI 并被强制转换为本地 API 对象。然后,本地 API 结果将被转换回 Java 结果。

    例如,这就是 bert_question_answerer_jni 的实现方式。

      // Implements BertQuestionAnswerer::initJniWithBertByteBuffers
     
    extern "C" JNIEXPORT jlong JNICALL
     
    Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
         
    JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
       
    // Convert Java ByteBuffer object into a buffer that can be read by native factory functions
        absl
    ::string_view model =
           
    GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));

       
    // Creates the native API object
        absl
    ::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
           
    BertQuestionAnswerer::CreateFromBuffer(
                model
    .data(), model.size());
       
    if (status.ok()) {
         
    // converts the object pointer to jlong and return to Java.
         
    return reinterpret_cast<jlong>(status->release());
       
    } else {
         
    return kInvalidPointer;
       
    }
     
    }

     
    // Implements BertQuestionAnswerer::answerNative
     
    extern "C" JNIEXPORT jobject JNICALL
     
    Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
     
    JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) {
     
    // Convert long to native API object pointer
     
    QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);

     
    // Calls the native API
      std
    ::vector<QaAnswer> results = question_answerer->Answer(JStringToString(env, context),
                                             
    JStringToString(env, question));

     
    // Converts native result(std::vector<QaAnswer>) to Java result(List<QaAnswerer>)
      jclass qa_answer_class
    =
        env
    ->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer");
      jmethodID qa_answer_ctor
    =
        env
    ->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V");
     
    return ConvertVectorToArrayList<QaAnswer>(
        env
    , results,
       
    [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) {
          jstring text
    = env->NewStringUTF(ans.text.data());
          jobject qa_answer
    =
              env
    ->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start,
                             ans
    .pos.end, ans.pos.logit);
          env
    ->DeleteLocalRef(text);
         
    return qa_answer;
       
    });
     
    }

     
    // Implements BaseTaskApi::deinitJni by delete the native object
     
    extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni(
         
    JNIEnv* env, jobject thiz, jlong native_handle) {
       
    delete reinterpret_cast<QuestionAnswerer*>(native_handle);
     
    }

iOS API

通过将本地 API 对象包装到 ObjC API 对象中来创建 iOS API。创建的 API 对象可以在 ObjC 或 Swift 中使用。iOS API 要求首先构建本地 API。

示例用法

以下是一个使用 ObjC TFLBertQuestionAnswerer 用于 MobileBert 的 Swift 示例。

  static let mobileBertModelPath = "path/to/model.tflite";
 
// Create the API from a model file and vocabulary file
  let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
      modelPath
: mobileBertModelPath)

 
static let context = ...; // context of a question to be answered
  static let question = ...; // question to be answered
  // ask a question
  let answers = mobileBertAnswerer.answer(
      context
: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
 
// answers.[0].text is the best answer

构建 API

ios_task_api
图 4. iOS 任务 API

iOS API 是在本地 API 之上的一个简单的 ObjC 包装器。按照以下步骤构建 API

  • 定义 ObjC 包装器 - 定义一个 ObjC 类并将实现委托给相应的本地 API 对象。请注意,由于 Swift 无法与 C++ 交互,因此本地依赖项只能出现在 .mm 文件中。

    • .h 文件
      @interface TFLBertQuestionAnswerer : NSObject

     
    // Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer
     
    + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath
                                                    vocabPath
    :(NSString*)vocabPath
          NS_SWIFT_NAME
    (mobilebertQuestionAnswerer(modelPath:vocabPath:));

     
    // Delegate calls to the native BertQuestionAnswerer::Answer
     
    - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
                                         question
    :(NSString*)question
          NS_SWIFT_NAME
    (answer(context:question:));
    }
    • .mm 文件
      using BertQuestionAnswererCPP = ::tflite::task::text::BertQuestionAnswerer;

     
    @implementation TFLBertQuestionAnswerer {
       
    // define an iVar for the native API object
        std
    ::unique_ptr<QuestionAnswererCPP> _bertQuestionAnswerwer;
     
    }

     
    // Initialize the native API object
     
    + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath
                                              vocabPath
    :(NSString *)vocabPath {
        absl
    ::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer =
           
    BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath),
                                                               
    MakeString(vocabPath));
        _GTMDevAssert
    (cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer");
       
    return [[TFLBertQuestionAnswerer alloc]
            initWithQuestionAnswerer
    :std::move(cQuestionAnswerer.value())];
     
    }

     
    // Calls the native API and converts C++ results into ObjC results
     
    - (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question {
        std
    ::vector<QaAnswerCPP> results =
          _bertQuestionAnswerwer
    ->Answer(MakeString(context), MakeString(question));
       
    return [self arrayFromVector:results];
     
    }
    }