TensorFlow 2 初学者快速入门

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

这个简短的介绍使用 Keras

  1. 加载预构建的数据集。
  2. 构建一个对图像进行分类的神经网络机器学习模型。
  3. 训练这个神经网络。
  4. 评估模型的准确性。

本教程是一个 Google Colaboratory 笔记本。Python 程序直接在浏览器中运行——这是一种学习和使用 TensorFlow 的绝佳方式。要学习本教程,请通过点击此页面顶部的按钮在 Google Colab 中运行笔记本。

  1. 在 Colab 中,连接到 Python 运行时:在菜单栏的右上角,选择连接
  2. 要运行笔记本中的所有代码,请选择运行时 > 全部运行。要一次运行一个代码单元,请将鼠标悬停在每个单元格上,然后选择运行单元格图标。

Run cell icon

设置 TensorFlow

将 TensorFlow 导入您的程序以开始使用

import tensorflow as tf
print("TensorFlow version:", tf.__version__)
2024-07-13 06:52:18.998540: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-13 06:52:19.024905: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-13 06:52:19.024943: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
TensorFlow version: 2.16.2

如果您是在自己的开发环境中学习,而不是在 Colab 中,请查看 安装指南,了解如何为开发设置 TensorFlow。

加载数据集

加载并准备 MNIST 数据集。图像的像素值范围为 0 到 255。通过将值除以255.0,将这些值缩放到 0 到 1 的范围内。这也将样本数据从整数转换为浮点数。

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

构建机器学习模型

构建一个tf.keras.Sequential 模型

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/layers/reshaping/flatten.py:37: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
  super().__init__(**kwargs)

Sequential 适用于堆叠层,其中每一层都有一个输入张量和一个输出张量。层是具有已知数学结构的函数,可以重复使用并具有可训练的变量。大多数 TensorFlow 模型都由层组成。此模型使用FlattenDenseDropout 层。

对于每个示例,模型返回一个logitslog-odds 分数向量,每个类一个。

predictions = model(x_train[:1]).numpy()
predictions
array([[-0.16857487, -0.01702446, -0.18800262, -0.29049426, -0.1497654 ,
         0.01135813, -0.13169414,  0.28695202,  0.37672728,  0.13688533]],
      dtype=float32)

tf.nn.softmax 函数将这些 logits 转换为每个类的概率

tf.nn.softmax(predictions).numpy()
array([[0.08373321, 0.09743506, 0.08212215, 0.07412229, 0.08532309,
        0.10024013, 0.086879  , 0.13204761, 0.14445063, 0.11364685]],
      dtype=float32)

使用losses.SparseCategoricalCrossentropy 定义用于训练的损失函数

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

损失函数接受一个地面实况值向量和一个 logits 向量,并为每个示例返回一个标量损失。此损失等于真实类的负对数概率:如果模型对正确类有把握,则损失为零。

此未经训练的模型给出的概率接近随机(每个类 1/10),因此初始损失应接近-tf.math.log(1/10) ~= 2.3

loss_fn(y_train[:1], predictions).numpy()
2.3001866

在开始训练之前,使用 KerasModel.compile 配置和编译模型。将optimizer 类设置为adam,将loss 设置为之前定义的loss_fn 函数,并通过将metrics 参数设置为accuracy 来指定要评估的模型指标。

model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

训练和评估您的模型

使用Model.fit 方法调整模型参数并最小化损失

model.fit(x_train, y_train, epochs=5)
Epoch 1/5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720853545.131188  523089 service.cc:145] XLA service 0x7fb588008030 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1720853545.131230  523089 service.cc:153]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1720853545.131234  523089 service.cc:153]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1720853545.131237  523089 service.cc:153]   StreamExecutor device (2): Tesla T4, Compute Capability 7.5
I0000 00:00:1720853545.131240  523089 service.cc:153]   StreamExecutor device (3): Tesla T4, Compute Capability 7.5
116/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.5960 - loss: 1.3751
I0000 00:00:1720853546.834081  523089 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 1ms/step - accuracy: 0.8633 - loss: 0.4777
Epoch 2/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.9561 - loss: 0.1461
Epoch 3/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.9682 - loss: 0.1073
Epoch 4/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.9744 - loss: 0.0848
Epoch 5/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.9772 - loss: 0.0711
<keras.src.callbacks.history.History at 0x7fb750d4bc40>

Model.evaluate 方法检查模型的性能,通常在验证集测试集 上。

model.evaluate(x_test,  y_test, verbose=2)
313/313 - 1s - 3ms/step - accuracy: 0.9766 - loss: 0.0709
[0.07090681046247482, 0.9765999913215637]

图像分类器现在在这个数据集上训练到了约 98% 的准确率。要了解更多信息,请阅读TensorFlow 教程

如果您希望模型返回概率,可以包装训练后的模型,并将 softmax 附加到它

probability_model = tf.keras.Sequential([
  model,
  tf.keras.layers.Softmax()
])
probability_model(x_test[:5])
<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[1.2273183e-07, 8.7016083e-09, 5.6712506e-06, 1.6725733e-04,
        3.8532035e-11, 3.1167556e-07, 3.1886015e-13, 9.9981803e-01,
        6.9700377e-08, 8.4526291e-06],
       [3.8318063e-10, 5.3196814e-05, 9.9994433e-01, 1.2887319e-06,
        6.4013896e-16, 1.0868150e-06, 1.7267263e-09, 9.1941645e-14,
        1.7351501e-07, 1.2004752e-16],
       [8.1174614e-09, 9.9988675e-01, 2.9579714e-05, 1.9076538e-06,
        5.8524206e-06, 1.8120238e-06, 8.6344735e-06, 3.9572973e-05,
        2.5875805e-05, 3.4755587e-08],
       [9.9995661e-01, 1.8822339e-09, 1.1518332e-06, 4.0551839e-08,
        1.0557898e-08, 3.4465447e-07, 2.8761485e-06, 3.4590674e-05,
        1.2326135e-09, 4.4350991e-06],
       [2.7767928e-07, 4.2751531e-09, 3.1395750e-06, 3.2061170e-08,
        9.9942112e-01, 1.6970212e-07, 1.6067910e-06, 8.6861517e-05,
        7.6640458e-07, 4.8607125e-04]], dtype=float32)>

结论

恭喜!您已使用Keras API 使用预构建数据集训练了机器学习模型。

有关使用 Keras 的更多示例,请查看教程。要了解有关使用 Keras 构建模型的更多信息,请阅读指南。如果您想了解有关加载和准备数据的更多信息,请查看有关图像数据加载CSV 数据加载 的教程。