在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看 | 下载笔记本 | 查看 TF Hub 模型 |
此笔记本展示了如何使用 CropNet 木薯病害分类器 模型(来自 TensorFlow Hub)。该模型将木薯叶子的图像分类为 6 类之一:细菌性枯萎病、褐条病、绿螨、花叶病、健康或未知。
此 Colab 演示了如何
- 从 TensorFlow Hub 加载 https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 模型
- 从 TensorFlow Datasets (TFDS) 加载 木薯 数据集
- 将木薯叶子的图像分类为 4 种不同的木薯病害类别,或分类为健康或未知。
- 评估分类器的准确率,并查看将模型应用于域外图像时的鲁棒性。
导入和设置
pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
用于显示示例的辅助函数
数据集
让我们从 TFDS 加载木薯数据集
dataset, info = tfds.load('cassava', with_info=True)
2024-03-09 13:44:07.128854: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
让我们看一下数据集信息,以了解有关它的更多信息,例如描述和引用,以及有关可用示例数量的信息
info
tfds.core.DatasetInfo( name='cassava', full_name='cassava/0.1.0', description=""" Cassava consists of leaf images for the cassava plant depicting healthy and four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD). Dataset consists of a total of 9430 labelled images. The 9430 labelled images are split into a training set (5656), a test set(1885) and a validation set (1889). The number of images per class are unbalanced with the two disease classes CMD and CBSD having 72% of the images. """, homepage='https://www.kaggle.com/c/cassava-disease/overview', data_dir='gs://tensorflow-datasets/datasets/cassava/0.1.0', file_format=tfrecord, download_size=1.26 GiB, dataset_size=Unknown size, features=FeaturesDict({ 'image': Image(shape=(None, None, 3), dtype=uint8), 'image/filename': Text(shape=(), dtype=string), 'label': ClassLabel(shape=(), dtype=int64, num_classes=5), }), supervised_keys=('image', 'label'), disable_shuffling=False, splits={ 'test': <SplitInfo num_examples=1885, num_shards=4>, 'train': <SplitInfo num_examples=5656, num_shards=8>, 'validation': <SplitInfo num_examples=1889, num_shards=4>, }, citation="""@misc{mwebaze2019icassava, title={iCassava 2019Fine-Grained Visual Categorization Challenge}, author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira}, year={2019}, eprint={1908.02900}, archivePrefix={arXiv}, primaryClass={cs.CV} }""", )
木薯数据集包含木薯叶子的图像,这些图像具有 4 种不同的疾病,以及健康的木薯叶子。该模型可以预测所有这些类别,以及当模型对其预测没有信心时,第六类“未知”。
# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']
# Map the class names to human readable names
name_map = dict(
cmd='Mosaic Disease',
cbb='Bacterial Blight',
cgm='Green Mite',
cbsd='Brown Streak Disease',
healthy='Healthy',
unknown='Unknown')
print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes: ['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown'] ['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']
在将数据馈送到模型之前,我们需要进行一些预处理。该模型期望 224 x 224 的图像,RGB 通道值在 [0, 1] 之间。让我们对图像进行归一化和调整大小。
def preprocess_fn(data):
image = data['image']
# Normalize [0, 255] to [0, 1]
image = tf.cast(image, tf.float32)
image = image / 255.
# Resize the images to 224 x 224
image = tf.image.resize(image, (224, 224))
data['image'] = image
return data
让我们看一下数据集中的几个示例
batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
模型
让我们从 TF Hub 加载分类器,并获得一些预测,并查看模型对几个示例的预测
classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
评估和鲁棒性
让我们在数据集的拆分上测量分类器的准确率。我们还可以查看模型的鲁棒性,方法是评估其在非木薯数据集上的性能。对于其他植物数据集(如 iNaturalist 或豆类)的图像,该模型几乎总是应该返回未知。
参数
def label_to_unknown_fn(data):
data['label'] = 5 # Override label to unknown.
return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
ds = ds.map(label_to_unknown_fn)
dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)
# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
probabilities = classifier(examples['image'])
predictions = tf.math.argmax(probabilities, axis=-1)
labels = examples['label']
metric.update_state(labels, predictions)
print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88 2024-03-09 13:44:27.693415: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
了解更多
- 在 TensorFlow Hub 上了解有关该模型的更多信息:https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
- 了解如何使用 ML Kit 构建在移动电话上运行的自定义图像分类器,以及 此模型的 TensorFlow Lite 版本。