TensorFlow Java 可以在任何 JVM 上运行,用于构建、训练和部署机器学习模型。它支持 CPU 和 GPU 执行,以图形或急切模式运行,并提供丰富的 API,以便在 JVM 环境中使用 TensorFlow。Java 和其他 JVM 语言(如 Scala 和 Kotlin)在全球大小企业中被广泛使用,这使得 TensorFlow Java 成为大规模采用机器学习的战略选择。
要求
TensorFlow Java 运行在 Java 8 及更高版本上,并开箱即用地支持以下平台
- Ubuntu 16.04 或更高版本;64 位,x86
- macOS 10.12.6 (Sierra) 或更高版本;64 位,x86
- Windows 7 或更高版本;64 位,x86
版本
TensorFlow Java 具有自己的发布周期,独立于 TensorFlow 运行时。因此,它的版本与它运行的 TensorFlow 运行时版本不匹配。请参阅 TensorFlow Java 版本表 以列出所有可用版本及其与 TensorFlow 运行时的映射。
工件
有 几种方法 可以将 TensorFlow Java 添加到您的项目中。最简单的方法是添加对 tensorflow-core-platform
工件的依赖项,其中包括 TensorFlow Java 核心 API 及其在所有支持平台上运行所需的本机依赖项。
您还可以选择以下扩展之一,而不是纯 CPU 版本
tensorflow-core-platform-mkl
:在所有平台上支持 Intel® MKL-DNNtensorflow-core-platform-gpu
:在 Linux 和 Windows 平台上支持 CUDA®tensorflow-core-platform-mkl-gpu
:在 Linux 平台上支持 Intel® MKL-DNN 和 CUDA®。
此外,可以添加对 tensorflow-framework
库的单独依赖项,以从 JVM 上基于 TensorFlow 的机器学习的丰富实用程序集受益。
使用 Maven 安装
要将 TensorFlow 包含在您的 Maven 应用程序中,请将对它的 工件 的依赖项添加到项目的 pom.xml
文件中。例如,
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.3.3</version>
</dependency>
减少依赖项数量
重要的是要注意,添加对 tensorflow-core-platform
工件的依赖项将导入所有支持平台的本机库,这会显着增加项目的规模。
如果您希望针对可用平台的子集,则可以使用 Maven 依赖项排除 功能从其他平台中排除不必要的工件。
另一种选择要将哪些平台包含在应用程序中的方法是在 Maven 命令行或 pom.xml
中设置 JavaCPP 系统属性。有关更多详细信息,请参阅 JavaCPP 文档。
使用快照
TensorFlow Java 源代码库中的最新 TensorFlow Java 开发快照可在 OSS Sonatype Nexus 存储库中获得。要依赖这些工件,请确保在您的 pom.xml
中配置 OSS 快照存储库。
<repositories>
<repository>
<id>tensorflow-snapshots</id>
<url>https://oss.sonatype.org/content/repositories/snapshots/</url>
<snapshots>
<enabled>true</enabled>
</snapshots>
</repository>
</repositories>
<dependencies>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.4.0-SNAPSHOT</version>
</dependency>
</dependencies>
使用 Gradle 安装
要将 TensorFlow 包含在您的 Gradle 应用程序中,请将对它的 工件 的依赖项添加到项目的 build.gradle
文件中。例如,
repositories {
mavenCentral()
}
dependencies {
compile group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '0.3.3'
}
减少依赖项数量
使用 Gradle 从 TensorFlow Java 中排除本机工件并不像使用 Maven 那样容易。我们建议您使用 Gradle JavaCPP 插件来减少依赖项数量。
有关更多详细信息,请参阅 Gradle JavaCPP 文档。
从源代码安装
要从源代码构建 TensorFlow Java,并可能对其进行自定义,请阅读以下 说明。
示例程序
此示例演示如何使用 TensorFlow 构建 Apache Maven 项目。首先,将 TensorFlow 依赖项添加到项目的 pom.xml
文件中
<project>
<modelVersion>4.0.0</modelVersion>
<groupId>org.myorg</groupId>
<artifactId>hellotensorflow</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<exec.mainClass>HelloTensorFlow</exec.mainClass>
<!-- Minimal version for compiling TensorFlow Java is JDK 8 -->
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<dependencies>
<!-- Include TensorFlow (pure CPU only) for all supported platforms -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.3.3</version>
</dependency>
</dependencies>
</project>
创建源文件 src/main/java/HelloTensorFlow.java
import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Add;
import org.tensorflow.types.TInt32;
public class HelloTensorFlow {
public static void main(String[] args) throws Exception {
System.out.println("Hello TensorFlow " + TensorFlow.version());
try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl);
TInt32 x = TInt32.scalarOf(10);
Tensor dblX = dbl.call(x)) {
System.out.println(x.getInt() + " doubled is " + ((TInt32)dblX).getInt());
}
}
private static Signature dbl(Ops tf) {
Placeholder<TInt32> x = tf.placeholder(TInt32.class);
Add<TInt32> dblX = tf.math.add(x, x);
return Signature.builder().input("x", x).output("dbl", dblX).build();
}
}
编译并执行
mvn -q compile exec:java
该命令打印 TensorFlow 版本和一个简单的计算结果。
成功!TensorFlow Java 已配置。