在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 | 查看 TF Hub 模型 |
在本 Colab 笔记本中,您将学习如何使用 TensorFlow Lite Model Maker 训练自定义音频分类模型。
Model Maker 库使用迁移学习来简化使用自定义数据集训练 TensorFlow Lite 模型的过程。使用您自己的自定义数据集重新训练 TensorFlow Lite 模型可以减少所需的训练数据量和时间。
它是 自定义音频模型并在 Android 上部署的 Codelab 的一部分。
您将使用自定义鸟类数据集并导出一个可以在手机上使用的 TFLite 模型,一个可以在浏览器中用于推理的 TensorFlow.JS 模型,以及一个可以用于服务的 SavedModel 版本。
安装依赖项
sudo apt -y install libportaudio2
pip install tflite-model-maker
导入 TensorFlow、Model Maker 和其他库
在所需的依赖项中,您将使用 TensorFlow 和 Model Maker。除此之外,其他依赖项用于音频操作、播放和可视化。
import tensorflow as tf
import tflite_model_maker as mm
from tflite_model_maker import audio_classifier
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import glob
import random
from IPython.display import Audio, Image
from scipy.io import wavfile
print(f"TensorFlow Version: {tf.__version__}")
print(f"Model Maker Version: {mm.__version__}")
鸟类数据集
鸟类数据集是一个包含 5 种鸟类鸣叫声的教育集合
- 白胸林莺
- 家麻雀
- 红交嘴雀
- 栗顶蚁鸫
- 阿扎拉的棘尾雀
原始音频来自 Xeno-canto,这是一个专门用于分享来自世界各地的鸟类声音的网站。
让我们从下载数据开始。
birds_dataset_folder = tf.keras.utils.get_file('birds_dataset.zip',
'https://storage.googleapis.com/laurencemoroney-blog.appspot.com/birds_dataset.zip',
cache_dir='./',
cache_subdir='dataset',
extract=True)
探索数据
音频已分为训练和测试文件夹。在每个拆分文件夹中,都有一个文件夹用于每种鸟类,使用它们的 bird_code
作为名称。
所有音频都是单声道,采样率为 16kHz。
有关每个文件的更多信息,您可以阅读 metadata.csv
文件。它包含所有文件的作者、许可证和更多信息。在本教程中,您无需自己阅读它。
# @title [Run this] Util functions and data structures.
data_dir = './dataset/small_birds_dataset'
bird_code_to_name = {
'wbwwre1': 'White-breasted Wood-Wren',
'houspa': 'House Sparrow',
'redcro': 'Red Crossbill',
'chcant2': 'Chestnut-crowned Antpitta',
'azaspi1': "Azara's Spinetail",
}
birds_images = {
'wbwwre1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/2/22/Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg/640px-Henicorhina_leucosticta_%28Cucarachero_pechiblanco%29_-_Juvenil_%2814037225664%29.jpg', # Alejandro Bayer Tamayo from Armenia, Colombia
'houspa': 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/52/House_Sparrow%2C_England_-_May_09.jpg/571px-House_Sparrow%2C_England_-_May_09.jpg', # Diliff
'redcro': 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Red_Crossbills_%28Male%29.jpg/640px-Red_Crossbills_%28Male%29.jpg', # Elaine R. Wilson, www.naturespicsonline.com
'chcant2': 'https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Chestnut-crowned_antpitta_%2846933264335%29.jpg/640px-Chestnut-crowned_antpitta_%2846933264335%29.jpg', # Mike's Birds from Riverside, CA, US
'azaspi1': 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b2/Synallaxis_azarae_76608368.jpg/640px-Synallaxis_azarae_76608368.jpg', # https://www.inaturalist.org/photos/76608368
}
test_files = os.path.abspath(os.path.join(data_dir, 'test/*/*.wav'))
def get_random_audio_file():
test_list = glob.glob(test_files)
random_audio_path = random.choice(test_list)
return random_audio_path
def show_bird_data(audio_path):
sample_rate, audio_data = wavfile.read(audio_path, 'rb')
bird_code = audio_path.split('/')[-2]
print(f'Bird name: {bird_code_to_name[bird_code]}')
print(f'Bird code: {bird_code}')
display(Image(birds_images[bird_code]))
plttitle = f'{bird_code_to_name[bird_code]} ({bird_code})'
plt.title(plttitle)
plt.plot(audio_data)
display(Audio(audio_data, rate=sample_rate))
print('functions and data structures created')
播放一些音频
为了更好地了解数据,让我们听一下测试拆分中的随机音频文件。
random_audio = get_random_audio_file()
show_bird_data(random_audio)
训练模型
使用 Model Maker 进行音频处理时,您需要从模型规范开始。这是您新模型将从中提取信息以学习新类别的基础模型。它还会影响数据集如何进行转换以符合模型规范参数,例如:采样率、声道数。
YAMNet 是一种音频事件分类器,它在 AudioSet 数据集上进行了训练,可以根据 AudioSet 本体预测音频事件。
它的输入预期为 16kHz,且具有 1 个声道。
您无需自己进行任何重采样。Model Maker 会为您处理。
frame_length
用于决定每个训练样本的长度。在本例中,EXPECTED_WAVEFORM_LENGTH * 3sframe_steps
用于决定训练样本之间的间隔。在本例中,第 i 个样本将从第 (i-1) 个样本后的 EXPECTED_WAVEFORM_LENGTH * 6s 开始。
设置这些值的原因是为了解决现实世界数据集中的某些限制。
例如,在鸟类数据集中,鸟类并不总是唱歌。它们会唱歌、休息,然后再次唱歌,期间会有噪音。较长的帧有助于捕获鸟鸣声,但设置过长会减少训练样本数量。
spec = audio_classifier.YamNetSpec(
keep_yamnet_and_custom_heads=True,
frame_step=3 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH,
frame_length=6 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH)
加载数据
Model Maker 具有 API,可以从文件夹加载数据,并将其以模型规范所需的格式保存。
训练集和测试集的划分基于文件夹。验证数据集将作为训练集的 20% 创建。
train_data = audio_classifier.DataLoader.from_folder(
spec, os.path.join(data_dir, 'train'), cache=True)
train_data, validation_data = train_data.split(0.8)
test_data = audio_classifier.DataLoader.from_folder(
spec, os.path.join(data_dir, 'test'), cache=True)
训练模型
audio_classifier 具有 create
方法,该方法可以创建模型并立即开始训练。
您可以自定义许多参数,有关更多信息,请阅读文档中的更多详细信息。
在第一次尝试中,您将使用所有默认配置,并训练 100 个 epoch。
batch_size = 128
epochs = 100
print('Training the model')
model = audio_classifier.create(
train_data,
spec,
validation_data,
batch_size=batch_size,
epochs=epochs)
准确率看起来不错,但重要的是在测试数据上运行评估步骤,并验证您的模型在未见数据上取得了良好的结果。
print('Evaluating the model')
model.evaluate(test_data)
了解您的模型
在训练分类器时,查看 混淆矩阵 很有用。混淆矩阵可以详细了解您的分类器在测试数据上的表现。
Model Maker 已经为您创建了混淆矩阵。
def show_confusion_matrix(confusion, test_labels):
"""Compute confusion matrix and normalize."""
confusion_normalized = confusion.astype("float") / confusion.sum(axis=1)
axis_labels = test_labels
ax = sns.heatmap(
confusion_normalized, xticklabels=axis_labels, yticklabels=axis_labels,
cmap='Blues', annot=True, fmt='.2f', square=True)
plt.title("Confusion matrix")
plt.ylabel("True label")
plt.xlabel("Predicted label")
confusion_matrix = model.confusion_matrix(test_data)
show_confusion_matrix(confusion_matrix.numpy(), test_data.index_to_label)
测试模型 [可选]
您可以尝试在测试数据集中的样本音频上使用模型,以查看结果。
首先,您需要获取服务模型。
serving_model = model.create_serving_model()
print(f'Model\'s input shape and type: {serving_model.inputs}')
print(f'Model\'s output shape and type: {serving_model.outputs}')
回到您之前加载的随机音频
# if you want to try another file just uncoment the line below
random_audio = get_random_audio_file()
show_bird_data(random_audio)
创建的模型具有固定的输入窗口。
对于给定的音频文件,您需要将其分割成预期大小的数据窗口。最后一个窗口可能需要用零填充。
sample_rate, audio_data = wavfile.read(random_audio, 'rb')
audio_data = np.array(audio_data) / tf.int16.max
input_size = serving_model.input_shape[1]
splitted_audio_data = tf.signal.frame(audio_data, input_size, input_size, pad_end=True, pad_value=0)
print(f'Test audio path: {random_audio}')
print(f'Original size of the audio data: {len(audio_data)}')
print(f'Number of windows for inference: {len(splitted_audio_data)}')
您将循环遍历所有分割的音频,并对每个音频应用模型。
您刚刚训练的模型有两个输出:原始 YAMNet 的输出和您刚刚训练的输出。这很重要,因为现实世界环境比仅仅是鸟鸣声要复杂得多。您可以使用 YAMNet 的输出来过滤掉不相关的音频,例如,在鸟类用例中,如果 YAMNet 没有将音频分类为鸟类或动物,这可能表明您的模型的输出可能具有不相关的分类。
下面打印了两个输出,以便更容易理解它们之间的关系。您的模型犯的大多数错误都是当 YAMNet 的预测与您的领域(例如:鸟类)无关时。
print(random_audio)
results = []
print('Result of the window ith: your model class -> score, (spec class -> score)')
for i, data in enumerate(splitted_audio_data):
yamnet_output, inference = serving_model(data)
results.append(inference[0].numpy())
result_index = tf.argmax(inference[0])
spec_result_index = tf.argmax(yamnet_output[0])
t = spec._yamnet_labels()[spec_result_index]
result_str = f'Result of the window {i}: ' \
f'\t{test_data.index_to_label[result_index]} -> {inference[0][result_index].numpy():.3f}, ' \
f'\t({spec._yamnet_labels()[spec_result_index]} -> {yamnet_output[0][spec_result_index]:.3f})'
print(result_str)
results_np = np.array(results)
mean_results = results_np.mean(axis=0)
result_index = mean_results.argmax()
print(f'Mean result: {test_data.index_to_label[result_index]} -> {mean_results[result_index]}')
导出模型
最后一步是导出您的模型,以便在嵌入式设备或浏览器上使用。
export
方法为您导出两种格式。
models_path = './birds_models'
print(f'Exporing the TFLite model to {models_path}')
model.export(models_path, tflite_filename='my_birds_model.tflite')
您还可以导出 SavedModel 版本,以便在 Python 环境中进行服务或使用。
model.export(models_path, export_format=[mm.ExportFormat.SAVED_MODEL, mm.ExportFormat.LABEL])
下一步
您做到了。
现在,您可以使用 TFLite AudioClassifier 任务 API 将您的新模型部署到移动设备上。
您也可以尝试使用自己的数据和不同的类别进行相同的过程,以下是 Model Maker for Audio Classification 的文档。