安装 TensorFlow Java

TensorFlow Java 可以在任何 JVM 上运行,用于构建、训练和部署机器学习模型。它支持 CPU 和 GPU 执行,以图形或急切模式运行,并提供丰富的 API,以便在 JVM 环境中使用 TensorFlow。Java 和其他 JVM 语言(如 Scala 和 Kotlin)在全球大小企业中被广泛使用,这使得 TensorFlow Java 成为大规模采用机器学习的战略选择。

要求

TensorFlow Java 运行在 Java 8 及更高版本上,并开箱即用地支持以下平台

  • Ubuntu 16.04 或更高版本;64 位,x86
  • macOS 10.12.6 (Sierra) 或更高版本;64 位,x86
  • Windows 7 或更高版本;64 位,x86

版本

TensorFlow Java 具有自己的发布周期,独立于 TensorFlow 运行时。因此,它的版本与它运行的 TensorFlow 运行时版本不匹配。请参阅 TensorFlow Java 版本表 以列出所有可用版本及其与 TensorFlow 运行时的映射。

工件

几种方法 可以将 TensorFlow Java 添加到您的项目中。最简单的方法是添加对 tensorflow-core-platform 工件的依赖项,其中包括 TensorFlow Java 核心 API 及其在所有支持平台上运行所需的本机依赖项。

您还可以选择以下扩展之一,而不是纯 CPU 版本

  • tensorflow-core-platform-mkl:在所有平台上支持 Intel® MKL-DNN
  • tensorflow-core-platform-gpu:在 Linux 和 Windows 平台上支持 CUDA®
  • tensorflow-core-platform-mkl-gpu:在 Linux 平台上支持 Intel® MKL-DNN 和 CUDA®。

此外,可以添加对 tensorflow-framework 库的单独依赖项,以从 JVM 上基于 TensorFlow 的机器学习的丰富实用程序集受益。

使用 Maven 安装

要将 TensorFlow 包含在您的 Maven 应用程序中,请将对它的 工件 的依赖项添加到项目的 pom.xml 文件中。例如,

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow-core-platform</artifactId>
  <version>0.3.3</version>
</dependency>

减少依赖项数量

重要的是要注意,添加对 tensorflow-core-platform 工件的依赖项将导入所有支持平台的本机库,这会显着增加项目的规模。

如果您希望针对可用平台的子集,则可以使用 Maven 依赖项排除 功能从其他平台中排除不必要的工件。

另一种选择要将哪些平台包含在应用程序中的方法是在 Maven 命令行或 pom.xml 中设置 JavaCPP 系统属性。有关更多详细信息,请参阅 JavaCPP 文档

使用快照

TensorFlow Java 源代码库中的最新 TensorFlow Java 开发快照可在 OSS Sonatype Nexus 存储库中获得。要依赖这些工件,请确保在您的 pom.xml 中配置 OSS 快照存储库。

<repositories>
    <repository>
        <id>tensorflow-snapshots</id>
        <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
        <snapshots>
            <enabled>true</enabled>
        </snapshots>
    </repository>
</repositories>

<dependencies>
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow-core-platform</artifactId>
        <version>0.4.0-SNAPSHOT</version>
    </dependency>
</dependencies>

使用 Gradle 安装

要将 TensorFlow 包含在您的 Gradle 应用程序中,请将对它的 工件 的依赖项添加到项目的 build.gradle 文件中。例如,

repositories {
    mavenCentral()
}

dependencies {
    compile group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '0.3.3'
}

减少依赖项数量

使用 Gradle 从 TensorFlow Java 中排除本机工件并不像使用 Maven 那样容易。我们建议您使用 Gradle JavaCPP 插件来减少依赖项数量。

有关更多详细信息,请参阅 Gradle JavaCPP 文档

从源代码安装

要从源代码构建 TensorFlow Java,并可能对其进行自定义,请阅读以下 说明

示例程序

此示例演示如何使用 TensorFlow 构建 Apache Maven 项目。首先,将 TensorFlow 依赖项添加到项目的 pom.xml 文件中

<project>
    <modelVersion>4.0.0</modelVersion>
    <groupId>org.myorg</groupId>
    <artifactId>hellotensorflow</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <exec.mainClass>HelloTensorFlow</exec.mainClass>
        <!-- Minimal version for compiling TensorFlow Java is JDK 8 -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
        <!-- Include TensorFlow (pure CPU only) for all supported platforms -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-platform</artifactId>
            <version>0.3.3</version>
        </dependency>
    </dependencies>
</project>

创建源文件 src/main/java/HelloTensorFlow.java

import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Add;
import org.tensorflow.types.TInt32;

public class HelloTensorFlow {

  public static void main(String[] args) throws Exception {
    System.out.println("Hello TensorFlow " + TensorFlow.version());

    try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl);
        TInt32 x = TInt32.scalarOf(10);
        Tensor dblX = dbl.call(x)) {
      System.out.println(x.getInt() + " doubled is " + ((TInt32)dblX).getInt());
    }
  }

  private static Signature dbl(Ops tf) {
    Placeholder<TInt32> x = tf.placeholder(TInt32.class);
    Add<TInt32> dblX = tf.math.add(x, x);
    return Signature.builder().input("x", x).output("dbl", dblX).build();
  }
}

编译并执行

mvn -q compile exec:java

该命令打印 TensorFlow 版本和一个简单的计算结果。

成功!TensorFlow Java 已配置。