This tutorial demonstrates how to classify a highly imbalanced dataset in which the number of examples in one class greatly outnumbers the examples in the other. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions from 284,807 transactions in total. You will use Keras to define the model, and class weights to help the model learn from the imbalanced data.
This tutorial contains complete code to:
- Load a CSV file using Pandas.
- Create train, validation, and test sets.
- Define and train a model using Keras (including setting class weights).
- Evaluate the model using various metrics (including precision and recall).
- Select a threshold for a probabilistic classifier to get a deterministic classifier.
- Try out and compare class-weighted modelling and oversampling.
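Since class weights are the main technique used later in this tutorial, the sketch below shows how per-class weights are typically derived from the class counts of this dataset. The `total / 2` scaling is one common convention (it keeps the overall loss magnitude comparable), not the only choice.

```python
# Sketch: derive class weights from the class counts of this dataset.
# Scaling by total/2 keeps the overall loss magnitude roughly unchanged.
neg, pos = 284315, 492
total = neg + pos

weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)

print(f'Weight for class 0: {weight_for_0:.2f}')
print(f'Weight for class 1: {weight_for_1:.2f}')
```

These weights would then be passed to Keras via `model.fit(..., class_weight={0: weight_for_0, 1: weight_for_1})`.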
Setup
import tensorflow as tf
from tensorflow import keras
import os
import tempfile
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
Data processing and exploration
Download the Kaggle Credit Card Fraud dataset
Pandas is a Python library with many helpful utilities for loading and working with structured data. It can be used to download CSVs into a Pandas DataFrame.
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()
Examine the class label imbalance
Let's look at the dataset imbalance:
neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n Total: {}\n Positive: {} ({:.2f}% of total)\n'.format(
total, pos, 100 * pos / total))
Examples: Total: 284807 Positive: 492 (0.17% of total)
This shows the small fraction of positive samples.
Clean, split and normalize the data
The raw data has a few issues. First, the `Time` and `Amount` columns are too variable to use directly. Drop the `Time` column (since its meaning is unclear) and take the log of the `Amount` column to reduce its range.
cleaned_df = raw_df.copy()
# You don't want the `Time` column.
cleaned_df.pop('Time')
# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Amount'] = np.log(cleaned_df.pop('Amount')+eps)
Split the dataset into train, validation, and test sets. The validation set is used during model fitting to evaluate the loss and any metrics; however, the model is not fit on this data. The test set is completely unused during the training phase and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets, where overfitting is a significant concern due to the lack of training data.
# Use a utility from sklearn to split and shuffle your dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)
# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))
train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)
Check that the class distribution is roughly the same across the three sets:
print(f'Average class probability in training set: {train_labels.mean():.4f}')
print(f'Average class probability in validation set: {val_labels.mean():.4f}')
print(f'Average class probability in test set: {test_labels.mean():.4f}')
Average class probability in training set: 0.0016 Average class probability in validation set: 0.0018 Average class probability in test set: 0.0019
Given the small number of positive labels, this seems about right.
Normalize the input features using the sklearn StandardScaler. This will set the mean to 0 and standard deviation to 1.
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)
train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)
print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)
print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,) Validation labels shape: (45569,) Test labels shape: (56962,) Training features shape: (182276, 29) Validation features shape: (45569, 29) Test features shape: (56962, 29)
Look at the data distribution
Next, compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:
- Do these distributions make sense?
  - Yes. You've normalized the input, and these are mostly concentrated in the +/- 2 range.
- Can you see the difference between the distributions?
  - Yes, the positive examples contain a much higher rate of extreme values.
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)
sns.jointplot(x=pos_df['V5'], y=pos_df['V6'],
kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")
sns.jointplot(x=neg_df['V5'], y=neg_df['V6'],
kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")
Define the model and metrics
Define a function that creates a simple neural network with a densely connected hidden layer, a dropout layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent:
METRICS = [
keras.metrics.BinaryCrossentropy(name='cross entropy'), # same as model's loss
keras.metrics.MeanSquaredError(name='Brier score'),
keras.metrics.TruePositives(name='tp'),
keras.metrics.FalsePositives(name='fp'),
keras.metrics.TrueNegatives(name='tn'),
keras.metrics.FalseNegatives(name='fn'),
keras.metrics.BinaryAccuracy(name='accuracy'),
keras.metrics.Precision(name='precision'),
keras.metrics.Recall(name='recall'),
keras.metrics.AUC(name='auc'),
keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]
def make_model(metrics=METRICS, output_bias=None):
if output_bias is not None:
output_bias = tf.keras.initializers.Constant(output_bias)
model = keras.Sequential([
keras.layers.Dense(
16, activation='relu',
input_shape=(train_features.shape[-1],)),
keras.layers.Dropout(0.5),
keras.layers.Dense(1, activation='sigmoid',
bias_initializer=output_bias),
])
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=1e-3),
loss=keras.losses.BinaryCrossentropy(),
metrics=metrics)
return model
Understanding useful metrics
Notice that a few of the metrics defined above can be computed by the model and will be helpful when evaluating performance. These can be divided into three groups.
Metrics for probability predictions
As the network is trained with cross entropy as the loss, it is fully capable of predicting class probabilities, i.e. it is a probabilistic classifier. Good metrics to assess probabilistic predictions are, in fact, proper scoring rules. Their key property is that predicting the true probability is optimal. We give two well-known examples:
- Cross entropy, also known as log loss
- Mean squared error, also known as the Brier score
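As a quick illustration of the "proper" property on made-up data (a sketch, not part of the tutorial's pipeline): under both scoring rules, predicting the true base rate scores strictly better than a miscalibrated probability.

```python
import numpy as np
from sklearn.metrics import log_loss, brier_score_loss

# Toy data: one positive in four examples, so the true base rate is 0.25.
y_true = np.array([0, 0, 0, 1])
p_true = np.full(4, 0.25)  # predict the true probability
p_off = np.full(4, 0.5)    # predict a miscalibrated probability

# Both proper scoring rules prefer the true probability (lower is better).
print(log_loss(y_true, p_true), log_loss(y_true, p_off))
print(brier_score_loss(y_true, p_true), brier_score_loss(y_true, p_off))
```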
Metrics for deterministic 0/1 predictions
Finally, one often wants to predict a class label, 0 or 1, no fraud or fraud. This is called a deterministic classifier. To get a label prediction from our probabilistic classifier, one needs to choose a probability threshold \(t\). The default is to predict label 1 (fraud) if the predicted probability is larger than \(t=50\%\), and all the following metrics implicitly use this default.
- False negatives and false positives are samples that were incorrectly classified
- True negatives and true positives are samples that were correctly classified
- Accuracy is the percentage of examples correctly classified > \(\frac{\text{true samples} }{\text{total samples} }\)
- Precision is the percentage of predicted positives that were correctly classified > \(\frac{\text{true positives} }{\text{true positives + false positives} }\)
- Recall is the percentage of actual positives that were correctly classified > \(\frac{\text{true positives} }{\text{true positives + false negatives} }\)
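A minimal sketch (with made-up labels and predicted probabilities) of how these quantities fall out of the thresholded predictions:

```python
import numpy as np

# Hypothetical labels and predicted probabilities, for illustration only.
labels = np.array([0, 0, 1, 1, 1, 0])
probs = np.array([0.1, 0.6, 0.8, 0.3, 0.9, 0.2])

preds = probs > 0.5  # the default threshold t = 0.5

tp = np.sum(preds & (labels == 1))   # correctly flagged positives
fp = np.sum(preds & (labels == 0))   # false alarms
fn = np.sum(~preds & (labels == 1))  # missed positives

precision = tp / (tp + fp)
recall = tp / (tp + fn)
print(precision, recall)
```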
Other metrics
The following metrics take into account all possible choices of the threshold \(t\):
- AUC refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.
- AUPRC refers to the Area Under the Curve of the Precision-Recall curve. This metric computes precision-recall pairs for different probability thresholds.
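Both threshold-free metrics are available in sklearn. A small sketch on made-up data (the same kind of toy labels as above, not this tutorial's dataset):

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

labels = np.array([0, 0, 1, 1, 1, 0])
probs = np.array([0.1, 0.6, 0.8, 0.3, 0.9, 0.2])

# ROC-AUC: probability that a random positive ranks above a random negative.
auc = roc_auc_score(labels, probs)
# average_precision_score summarizes the precision-recall curve (AUPRC).
auprc = average_precision_score(labels, probs)
print(auc, auprc)
```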
Read more:
Baseline model
Build the model
Now create and train your model using the function that was defined earlier. Notice that the model is fit using a larger than default batch size of 2048; this is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size was too small, they would likely have no fraudulent transactions to learn from.
EPOCHS = 100
BATCH_SIZE = 2048
early_stopping = tf.keras.callbacks.EarlyStopping(
monitor='val_prc',
verbose=1,
patience=10,
mode='max',
restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 16) 480 dropout (Dropout) (None, 16) 0 dense_1 (Dense) (None, 1) 17 ================================================================= Total params: 497 (1.94 KB) Trainable params: 497 (1.94 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________
Test run the model:
model.predict(train_features[:10])
1/1 [==============================] - 0s 471ms/step array([[0.16263928], [0.35204744], [0.19377157], [0.72603256], [0.30116165], [0.25605297], [0.66053736], [0.31973222], [0.25077152], [0.26151225]], dtype=float32)
Optional: Set the correct initial bias.
These initial guesses are not great. You know the dataset is imbalanced. Set the output layer's bias to reflect that (see: A Recipe for Training Neural Networks: "init well"). This can help with initial convergence.
With the default bias initialization the loss should be about math.log(2) = 0.69314
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.4088
The correct bias to set can be derived from:
\[ p_0 = pos/(pos + neg) = 1/(1+e^{-b_0}) \]
\[ b_0 = -log_e(1/p_0 - 1) \]
\[ b_0 = log_e(pos/neg)\]
initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])
Set that as the initial bias, and the model will give much more reasonable initial guesses.
It should be near: pos/total = 0.0018
model = make_model(output_bias=initial_bias)
model.predict(train_features[:10])
1/1 [==============================] - 0s 75ms/step array([[0.00135984], [0.00134607], [0.00213977], [0.01406598], [0.0021732 ], [0.00640495], [0.00814889], [0.00254694], [0.00572464], [0.00216844]], dtype=float32)
With this initialization the initial loss should be approximately:
\[-p_0log(p_0)-(1-p_0)log(1-p_0) = 0.01317\]
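A quick sanity check of this formula, using the rounded base rate \(p_0 \approx 0.0018\) from above:

```python
import math

# Approximate positive rate, pos/total, rounded as in the text above.
p0 = 0.0018
initial_loss = -p0 * math.log(p0) - (1 - p0) * math.log(1 - p0)
print(f'{initial_loss:.5f}')  # ~0.01317
```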
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0087
This initial loss is about 50 times smaller than it would have been with naive initialization.
This way the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. It also makes it easier to read plots of the loss during training.
Checkpoint the initial weights
To make the various training runs more comparable, keep this initial model's weights in a checkpoint file and load them into each model before training:
initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')
model.save_weights(initial_weights)
Confirm that the bias fix helps
Before moving on, confirm quickly that the careful bias initialization actually helped.
Train a model for 20 epochs, with and without this careful initialization, and compare the losses:
model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=20,
validation_data=(val_features, val_labels),
verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=20,
validation_data=(val_features, val_labels),
verbose=0)
def plot_loss(history, label, n):
# Use a log scale on y-axis to show the wide range of values.
plt.semilogy(history.epoch, history.history['loss'],
color=colors[n], label='Train ' + label)
plt.semilogy(history.epoch, history.history['val_loss'],
color=colors[n], label='Val ' + label,
linestyle="--")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)
The above figure makes it clear: in terms of validation loss, on this problem, this careful initialization gives a clear advantage.
Train the model
model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
train_features,
train_labels,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks=[early_stopping],
validation_data=(val_features, val_labels))
Epoch 1/100 90/90 [==============================] - 2s 11ms/step - loss: 0.0109 - precision: 0.5831 - recall: 0.4698 - auc: 0.8759 - prc: 0.4240 - val_loss: 0.0053 - val_precision: 0.8980 - val_recall: 0.5366 - val_auc: 0.9188 - val_prc: 0.7535
[... per-epoch logs for epochs 2-41 omitted ...]
Restoring model weights from the end of the best epoch: 32.
Epoch 42: early stopping
Check training history
In this section, you will produce plots of your model's accuracy and loss on the training and validation sets. These are useful to check for overfitting, which you can learn more about in the Overfit and underfit tutorial.
Additionally, you can produce these plots for any of the metrics you created above. False negatives are included as an example.
def plot_metrics(history):
  metrics = ['loss', 'prc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_", " ").capitalize()
    plt.subplot(2, 2, n + 1)
    plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_' + metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8, 1])
    else:
      plt.ylim([0, 1])

    plt.legend()
plot_metrics(baseline_history)
Evaluate metrics
You can use a confusion matrix to summarize the actual vs. predicted labels, where the X axis is the predicted label and the Y axis is the actual label:
train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step 28/28 [==============================] - 0s 1ms/step
def plot_cm(labels, predictions, threshold=0.5):
  cm = confusion_matrix(labels, predictions > threshold)
  plt.figure(figsize=(5, 5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(threshold))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))
Evaluate your model on the test dataset and display the results for the metrics you created above:
baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
loss : 0.0038855739403516054 cross entropy : 0.0038855739403516054 Brier score : 0.0006162827485240996 tp : 81.0 fp : 11.0 tn : 56840.0 fn : 30.0 accuracy : 0.9992802143096924 precision : 0.8804348111152649 recall : 0.7297297120094299 auc : 0.9096326231956482 prc : 0.7863917350769043 Legitimate Transactions Detected (True Negatives): 56840 Legitimate Transactions Incorrectly Detected (False Positives): 11 Fraudulent Transactions Missed (False Negatives): 30 Fraudulent Transactions Detected (True Positives): 81 Total Fraudulent Transactions: 111
If the model had predicted everything perfectly (impossible with true randomness), this would be a diagonal matrix, where values off the main diagonal, indicating incorrect predictions, would be zero. In this case, the matrix shows that you have relatively few false positives, meaning that relatively few legitimate transactions were incorrectly flagged as fraudulent.
Changing the threshold
The default threshold of \(t=50\%\) corresponds to equal costs of false negatives and false positives. In the case of fraud detection, however, you would likely associate higher costs with false negatives than with false positives. This trade-off may be preferable because false negatives would allow fraudulent transactions to go through, whereas false positives may cause an email to be sent to a customer asking them to verify their card activity.
By lowering the threshold, you attribute a higher cost to false negatives, thereby decreasing the number of missed fraudulent transactions at the price of more false positives. Thresholds of 10% and 1% are tested below.
plot_cm(test_labels, test_predictions_baseline, threshold=0.1)
plot_cm(test_labels, test_predictions_baseline, threshold=0.01)
Legitimate Transactions Detected (True Negatives): 56834 Legitimate Transactions Incorrectly Detected (False Positives): 17 Fraudulent Transactions Missed (False Negatives): 23 Fraudulent Transactions Detected (True Positives): 88 Total Fraudulent Transactions: 111 Legitimate Transactions Detected (True Negatives): 56806 Legitimate Transactions Incorrectly Detected (False Positives): 45 Fraudulent Transactions Missed (False Negatives): 22 Fraudulent Transactions Detected (True Positives): 89 Total Fraudulent Transactions: 111
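The effect of sweeping the threshold can also be summarized numerically. This is a minimal sketch using synthetic labels and scores as stand-ins for the tutorial's `test_labels` and `test_predictions_baseline` (the array contents here are invented for illustration); lowering the threshold can only hold or raise recall, while precision typically drops:

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score

# Hypothetical scores/labels standing in for the tutorial's predictions.
rng = np.random.default_rng(0)
labels = (rng.random(2000) < 0.02).astype(int)          # ~2% positives
scores = np.clip(0.3 * labels + 0.7 * rng.random(2000), 0, 1)

for threshold in [0.5, 0.1, 0.01]:
  preds = scores > threshold
  print('threshold={}: precision={:.2f}, recall={:.2f}'.format(
      threshold,
      precision_score(labels, preds, zero_division=0),
      recall_score(labels, preds)))
```

The same loop applied to the real predictions reproduces the trade-off visible in the confusion matrices above.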
Plot the ROC
Now plot the ROC. This plot is useful because it shows, at a glance, the range of performance the model can reach by tuning the output threshold over its full range (0 to 1). So each point corresponds to a single value of the threshold.
def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100 * fp, 100 * tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5, 20])
  plt.ylim([80, 100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right');
Plot the PRC
Now plot the AUPRC: the area under the interpolated precision-recall curve, obtained by plotting (recall, precision) points for different values of the classification threshold. Depending on how it's calculated, PR AUC may be equivalent to the model's average precision.
def plot_prc(name, labels, predictions, **kwargs):
  precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, predictions)

  plt.plot(precision, recall, label=name, linewidth=2, **kwargs)
  plt.xlabel('Precision')
  plt.ylabel('Recall')
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right');
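The remark above that PR AUC "may be equivalent" to average precision can be checked directly: a trapezoidal integral of the PR curve and `sklearn`'s step-wise average precision agree closely when there are many thresholds. A small sketch with synthetic stand-in labels and scores (not the tutorial's predictions):

```python
import numpy as np
import sklearn.metrics

# Synthetic scores/labels as stand-ins for the tutorial's predictions.
rng = np.random.default_rng(42)
labels = (rng.random(5000) < 0.05).astype(int)
scores = np.clip(0.4 * labels + 0.6 * rng.random(5000), 0, 1)

precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, scores)
pr_auc = sklearn.metrics.auc(recall, precision)               # trapezoidal area
ap = sklearn.metrics.average_precision_score(labels, scores)  # step-wise area
print('PR AUC: {:.3f}  average precision: {:.3f}'.format(pr_auc, ap))
```

The two values differ only in the interpolation rule; `average_precision_score` avoids the optimistic linear interpolation of the trapezoidal rule.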
It looks like the precision is relatively high, but the recall and the area under the ROC curve (AUC) aren't as high as you might like. Classifiers often face challenges when trying to maximize both precision and recall, which is especially true when working with imbalanced datasets. It is important to consider the costs of different types of errors in the context of the problem you care about. In this example, a false negative (a fraudulent transaction is missed) may have a financial cost, while a false positive (a transaction is incorrectly flagged as fraudulent) may decrease user happiness.
Class weights
Calculate class weights
The goal is to identify fraudulent transactions, but you don't have very many of those positive examples to work with, so you would want the classifier to heavily weight the few examples that are available. You can do this by passing Keras weights for each class through a parameter. These will cause the model to "pay more attention" to examples from an under-represented class. Note, however, that this does not increase the amount of information in your dataset in any way. In the end, using class weights is more or less equivalent to changing the output bias or changing the threshold. Let's see how it works out.
# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)
class_weight = {0: weight_for_0, 1: weight_for_1}
print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50 Weight for class 1: 289.44
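The scaling comment above ("the sum of the weights of all examples stays the same") can be verified arithmetically: with this formula, each class receives the same total weight, and the overall weight still sums to `total`. A quick check using the dataset-wide counts from the beginning of the tutorial:

```python
import numpy as np

# Counts from the full dataset: 284,807 transactions, 492 fraudulent.
neg, pos = 284315, 492
total = neg + pos

weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)

# Each class gets the same total weight (total / 2), so the
# sum of all per-example weights is still `total`.
print('class 0 total weight:', neg * weight_for_0)
print('class 1 total weight:', pos * weight_for_1)
print('Weight for class 1: {:.2f}'.format(weight_for_1))
```

This is why the loss stays at a comparable magnitude to the unweighted case while each positive example contributes roughly 578 times more than each negative one.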
Train a model with class weights
Now try re-training and evaluating the model with class weights to see how that affects the predictions.
weighted_model = make_model()
weighted_model.load_weights(initial_weights)
weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight)
Epoch 1/100 90/90 [==============================] - 2s 11ms/step - loss: 0.9262 - cross entropy: 0.0166 - Brier score: 0.0027 - tp: 233.0000 - fp: 530.0000 - tn: 238298.0000 - fn: 177.0000 - accuracy: 0.9970 - precision: 0.3054 - recall: 0.5683 - auc: 0.8803 - prc: 0.4222 - val_loss: 0.0116 - val_cross entropy: 0.0116 - val_Brier score: 0.0011 - val_tp: 67.0000 - val_fp: 35.0000 - val_tn: 45452.0000 - val_fn: 15.0000 - val_accuracy: 0.9989 - val_precision: 0.6569 - val_recall: 0.8171 - val_auc: 0.9519 - val_prc: 0.7255 Epoch 2/100 90/90 [==============================] - 0s 5ms/step - loss: 0.6152 - cross entropy: 0.0319 - Brier score: 0.0059 - tp: 202.0000 - fp: 1117.0000 - tn: 180860.0000 - fn: 97.0000 - accuracy: 0.9933 - precision: 0.1531 - recall: 0.6756 - auc: 0.9051 - prc: 0.4410 - val_loss: 0.0172 - val_cross entropy: 0.0172 - val_Brier score: 0.0019 - val_tp: 69.0000 - val_fp: 72.0000 - val_tn: 45415.0000 - val_fn: 13.0000 - val_accuracy: 0.9981 - val_precision: 0.4894 - val_recall: 0.8415 - val_auc: 0.9577 - val_prc: 0.7220 Epoch 3/100 90/90 [==============================] - 0s 5ms/step - loss: 0.4397 - cross entropy: 0.0461 - Brier score: 0.0095 - tp: 229.0000 - fp: 1929.0000 - tn: 180048.0000 - fn: 70.0000 - accuracy: 0.9890 - precision: 0.1061 - recall: 0.7659 - auc: 0.9307 - prc: 0.4134 - val_loss: 0.0236 - val_cross entropy: 0.0236 - val_Brier score: 0.0029 - val_tp: 69.0000 - val_fp: 106.0000 - val_tn: 45381.0000 - val_fn: 13.0000 - val_accuracy: 0.9974 - val_precision: 0.3943 - val_recall: 0.8415 - val_auc: 0.9662 - val_prc: 0.7291 Epoch 4/100 90/90 [==============================] - 0s 5ms/step - loss: 0.4155 - cross entropy: 0.0619 - Brier score: 0.0136 - tp: 231.0000 - fp: 2898.0000 - tn: 179079.0000 - fn: 68.0000 - accuracy: 0.9837 - precision: 0.0738 - recall: 0.7726 - auc: 0.9272 - prc: 0.3804 - val_loss: 0.0319 - val_cross entropy: 0.0319 - val_Brier score: 0.0046 - val_tp: 70.0000 - val_fp: 188.0000 - val_tn: 45299.0000 - val_fn: 12.0000 
- val_accuracy: 0.9956 - val_precision: 0.2713 - val_recall: 0.8537 - val_auc: 0.9697 - val_prc: 0.7095 Epoch 5/100 90/90 [==============================] - 0s 5ms/step - loss: 0.3247 - cross entropy: 0.0773 - Brier score: 0.0178 - tp: 241.0000 - fp: 3872.0000 - tn: 178105.0000 - fn: 58.0000 - accuracy: 0.9784 - precision: 0.0586 - recall: 0.8060 - auc: 0.9471 - prc: 0.3673 - val_loss: 0.0405 - val_cross entropy: 0.0405 - val_Brier score: 0.0068 - val_tp: 71.0000 - val_fp: 334.0000 - val_tn: 45153.0000 - val_fn: 11.0000 - val_accuracy: 0.9924 - val_precision: 0.1753 - val_recall: 0.8659 - val_auc: 0.9714 - val_prc: 0.6518 Epoch 6/100 90/90 [==============================] - 0s 6ms/step - loss: 0.3481 - cross entropy: 0.0976 - Brier score: 0.0225 - tp: 248.0000 - fp: 4880.0000 - tn: 177097.0000 - fn: 51.0000 - accuracy: 0.9729 - precision: 0.0484 - recall: 0.8294 - auc: 0.9351 - prc: 0.3069 - val_loss: 0.0494 - val_cross entropy: 0.0494 - val_Brier score: 0.0093 - val_tp: 73.0000 - val_fp: 511.0000 - val_tn: 44976.0000 - val_fn: 9.0000 - val_accuracy: 0.9886 - val_precision: 0.1250 - val_recall: 0.8902 - val_auc: 0.9742 - val_prc: 0.6313 Epoch 7/100 90/90 [==============================] - 0s 5ms/step - loss: 0.2719 - cross entropy: 0.1078 - Brier score: 0.0253 - tp: 257.0000 - fp: 5673.0000 - tn: 176304.0000 - fn: 42.0000 - accuracy: 0.9686 - precision: 0.0433 - recall: 0.8595 - auc: 0.9564 - prc: 0.2894 - val_loss: 0.0565 - val_cross entropy: 0.0565 - val_Brier score: 0.0112 - val_tp: 73.0000 - val_fp: 633.0000 - val_tn: 44854.0000 - val_fn: 9.0000 - val_accuracy: 0.9859 - val_precision: 0.1034 - val_recall: 0.8902 - val_auc: 0.9757 - val_prc: 0.6267 Epoch 8/100 90/90 [==============================] - 0s 6ms/step - loss: 0.2623 - cross entropy: 0.1179 - Brier score: 0.0275 - tp: 262.0000 - fp: 6123.0000 - tn: 175854.0000 - fn: 37.0000 - accuracy: 0.9662 - precision: 0.0410 - recall: 0.8763 - auc: 0.9554 - prc: 0.2609 - val_loss: 0.0607 - val_cross entropy: 0.0607 
- val_Brier score: 0.0124 - val_tp: 73.0000 - val_fp: 686.0000 - val_tn: 44801.0000 - val_fn: 9.0000 - val_accuracy: 0.9847 - val_precision: 0.0962 - val_recall: 0.8902 - val_auc: 0.9754 - val_prc: 0.6069 Epoch 9/100 90/90 [==============================] - 0s 5ms/step - loss: 0.2915 - cross entropy: 0.1184 - Brier score: 0.0280 - tp: 257.0000 - fp: 6295.0000 - tn: 175682.0000 - fn: 42.0000 - accuracy: 0.9652 - precision: 0.0392 - recall: 0.8595 - auc: 0.9494 - prc: 0.2652 - val_loss: 0.0653 - val_cross entropy: 0.0653 - val_Brier score: 0.0135 - val_tp: 74.0000 - val_fp: 742.0000 - val_tn: 44745.0000 - val_fn: 8.0000 - val_accuracy: 0.9835 - val_precision: 0.0907 - val_recall: 0.9024 - val_auc: 0.9773 - val_prc: 0.5856 Epoch 10/100 90/90 [==============================] - 0s 6ms/step - loss: 0.2632 - cross entropy: 0.1336 - Brier score: 0.0313 - tp: 259.0000 - fp: 6976.0000 - tn: 175001.0000 - fn: 40.0000 - accuracy: 0.9615 - precision: 0.0358 - recall: 0.8662 - auc: 0.9561 - prc: 0.2365 - val_loss: 0.0700 - val_cross entropy: 0.0700 - val_Brier score: 0.0146 - val_tp: 76.0000 - val_fp: 801.0000 - val_tn: 44686.0000 - val_fn: 6.0000 - val_accuracy: 0.9823 - val_precision: 0.0867 - val_recall: 0.9268 - val_auc: 0.9773 - val_prc: 0.5876 Epoch 11/100 90/90 [==============================] - 1s 6ms/step - loss: 0.2336 - cross entropy: 0.1282 - Brier score: 0.0299 - tp: 269.0000 - fp: 6690.0000 - tn: 175287.0000 - fn: 30.0000 - accuracy: 0.9631 - precision: 0.0387 - recall: 0.8997 - auc: 0.9586 - prc: 0.2494 - val_loss: 0.0679 - val_cross entropy: 0.0679 - val_Brier score: 0.0140 - val_tp: 76.0000 - val_fp: 757.0000 - val_tn: 44730.0000 - val_fn: 6.0000 - val_accuracy: 0.9833 - val_precision: 0.0912 - val_recall: 0.9268 - val_auc: 0.9777 - val_prc: 0.5891 Epoch 12/100 90/90 [==============================] - 1s 6ms/step - loss: 0.2399 - cross entropy: 0.1289 - Brier score: 0.0298 - tp: 265.0000 - fp: 6654.0000 - tn: 175323.0000 - fn: 34.0000 - accuracy: 0.9633 - 
precision: 0.0383 - recall: 0.8863 - auc: 0.9602 - prc: 0.2601 - val_loss: 0.0684 - val_cross entropy: 0.0684 - val_Brier score: 0.0141 - val_tp: 76.0000 - val_fp: 762.0000 - val_tn: 44725.0000 - val_fn: 6.0000 - val_accuracy: 0.9831 - val_precision: 0.0907 - val_recall: 0.9268 - val_auc: 0.9784 - val_prc: 0.5848 Epoch 13/100 79/90 [=========================>....] - ETA: 0s - loss: 0.2286 - cross entropy: 0.1265 - Brier score: 0.0295 - tp: 237.0000 - fp: 5838.0000 - tn: 155684.0000 - fn: 33.0000 - accuracy: 0.9637 - precision: 0.0390 - recall: 0.8778 - auc: 0.9696 - prc: 0.2645Restoring model weights from the end of the best epoch: 3. 90/90 [==============================] - 1s 6ms/step - loss: 0.2341 - cross entropy: 0.1275 - Brier score: 0.0297 - tp: 262.0000 - fp: 6631.0000 - tn: 175346.0000 - fn: 37.0000 - accuracy: 0.9634 - precision: 0.0380 - recall: 0.8763 - auc: 0.9665 - prc: 0.2538 - val_loss: 0.0757 - val_cross entropy: 0.0757 - val_Brier score: 0.0159 - val_tp: 76.0000 - val_fp: 834.0000 - val_tn: 44653.0000 - val_fn: 6.0000 - val_accuracy: 0.9816 - val_precision: 0.0835 - val_recall: 0.9268 - val_auc: 0.9789 - val_prc: 0.5709 Epoch 13: early stopping
Check training history
plot_metrics(weighted_history)
Evaluate metrics
train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step 28/28 [==============================] - 0s 1ms/step
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss : 0.024716919288039207 cross entropy : 0.024716919288039207 Brier score : 0.0029473488684743643 tp : 88.0 fp : 134.0 tn : 56717.0 fn : 23.0 accuracy : 0.9972437620162964 precision : 0.3963963985443115 recall : 0.792792797088623 auc : 0.9477326273918152 prc : 0.6732124090194702 Legitimate Transactions Detected (True Negatives): 56717 Legitimate Transactions Incorrectly Detected (False Positives): 134 Fraudulent Transactions Missed (False Negatives): 23 Fraudulent Transactions Detected (True Positives): 88 Total Fraudulent Transactions: 111
Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also found more true positives. Despite having lower accuracy, this model has higher recall (and identifies more fraudulent transactions than the baseline model at the 50% threshold). Of course, both types of error have a cost (you would not want to bug users by flagging too many legitimate transactions as fraudulent, either). Carefully consider the trade-offs between these different types of errors for your application.
Compared to the baseline model with an adjusted threshold, the class-weighted model is clearly inferior. The baseline model's superiority is further confirmed by its lower test loss values (cross entropy and mean squared error), and can additionally be observed by plotting the ROC curves of both models together.
Plot the ROC
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plt.legend(loc='lower right');
Plot the PRC
plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plt.legend(loc='lower right');
Oversampling
Oversample the minority class
A related approach would be to resample the dataset by oversampling the minority class.
pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]
pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]
Using NumPy
You can balance the dataset manually by choosing the right number of random indices from the positive examples:
ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))
res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]
res_pos_features.shape
(181977, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)
order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]
resampled_features.shape
(363954, 29)
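The same resampling recipe can be sketched on toy arrays to confirm it produces a balanced, shuffled set. The small arrays below are hypothetical stand-ins for the tutorial's `pos_features`/`neg_features`; only the shapes differ:

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy stand-ins: 1000 negatives, 10 positives, 2 features each.
neg_features = rng.normal(size=(1000, 2))
pos_features = rng.normal(loc=3.0, size=(10, 2))
neg_labels = np.zeros(len(neg_features), dtype=int)
pos_labels = np.ones(len(pos_features), dtype=int)

# Oversample the positives with replacement to match the negative count.
choices = rng.choice(len(pos_features), len(neg_features))
res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

# Shuffle so batches are not all-positive followed by all-negative.
order = rng.permutation(len(resampled_labels))
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

print(np.bincount(resampled_labels))  # → [1000 1000]
```

Note that each positive example now appears about 100 times on average; the resampled set is balanced but contains no new information.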
Using tf.data
If you're using tf.data, the easiest way to produce balanced examples is to start with a positive and a negative dataset, and merge them. See the tf.data guide for more examples.
BUFFER_SIZE = 100000
def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))  # .cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds
pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)
Each dataset provides (feature, label) pairs:
for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
Features: [ 4.57437149e-03 1.41282803e+00 -1.70738347e+00 7.86145002e-01 2.34322123e+00 -1.32760854e+00 1.68238195e+00 -7.10272314e-01 8.18760297e-01 -3.09684905e+00 2.01295966e+00 -3.98984767e+00 1.02827419e+00 -5.00000000e+00 -1.25820263e+00 1.91494135e+00 5.00000000e+00 3.32009026e+00 -2.75342824e+00 -8.47588695e-03 -7.83382558e-01 -1.24259811e+00 -6.45039879e-01 -1.71393384e-02 1.13211907e+00 -1.52256293e+00 -1.08919872e+00 -1.06657977e+00 -1.45889491e+00] Label: 1
Merge the two together using tf.data.Dataset.sample_from_datasets:
resampled_ds = tf.data.Dataset.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
0.50341796875
To use this dataset, you'll need the number of steps per epoch.
The definition of "epoch" in this case is less clear. Say it's the number of batches required to see each negative example once:
resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0
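The 278 steps follow directly from the 50/50 sampling: since roughly half of each resampled batch is negative, consuming every negative example once requires about `2 * neg` examples in total. A minimal check, assuming `BATCH_SIZE = 2048` (the batch size used earlier in the tutorial) and using the dataset-wide negative count computed at the top:

```python
import numpy as np

neg = 284315       # negative examples in the full dataset
BATCH_SIZE = 2048  # assumed to match the tutorial's setting

# Batches drawn 50/50 are ~half negative, so seeing each negative
# once takes about 2 * neg examples, i.e. this many batches:
resampled_steps_per_epoch = np.ceil(2.0 * neg / BATCH_SIZE)
print(resampled_steps_per_epoch)  # → 278.0
```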
Train on the oversampled data
Now try training the model with the resampled dataset instead of using class weights to see how these methods compare.
resampled_model = make_model()
resampled_model.load_weights(initial_weights)
# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1]
output_layer.bias.assign([0])
val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2)
resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks=[early_stopping],
    validation_data=val_ds)
Epoch 1/100 278/278 [==============================] - 8s 22ms/step - loss: 0.3612 - cross entropy: 0.3306 - Brier score: 0.1096 - tp: 258749.0000 - fp: 76958.0000 - tn: 264513.0000 - fn: 26086.0000 - accuracy: 0.8355 - precision: 0.7708 - recall: 0.9084 - auc: 0.9490 - prc: 0.9536 - val_loss: 0.2021 - val_cross entropy: 0.2021 - val_Brier score: 0.0446 - val_tp: 75.0000 - val_fp: 1144.0000 - val_tn: 44343.0000 - val_fn: 7.0000 - val_accuracy: 0.9747 - val_precision: 0.0615 - val_recall: 0.9146 - val_auc: 0.9741 - val_prc: 0.7919 Epoch 2/100 278/278 [==============================] - 5s 20ms/step - loss: 0.1757 - cross entropy: 0.1757 - Brier score: 0.0515 - tp: 262968.0000 - fp: 15885.0000 - tn: 269124.0000 - fn: 21367.0000 - accuracy: 0.9346 - precision: 0.9430 - recall: 0.9249 - auc: 0.9817 - prc: 0.9852 - val_loss: 0.1003 - val_cross entropy: 0.1003 - val_Brier score: 0.0205 - val_tp: 76.0000 - val_fp: 858.0000 - val_tn: 44629.0000 - val_fn: 6.0000 - val_accuracy: 0.9810 - val_precision: 0.0814 - val_recall: 0.9268 - val_auc: 0.9777 - val_prc: 0.7702 Epoch 3/100 278/278 [==============================] - 6s 21ms/step - loss: 0.1358 - cross entropy: 0.1358 - Brier score: 0.0398 - tp: 266700.0000 - fp: 11006.0000 - tn: 273451.0000 - fn: 18187.0000 - accuracy: 0.9487 - precision: 0.9604 - recall: 0.9362 - auc: 0.9891 - prc: 0.9904 - val_loss: 0.0725 - val_cross entropy: 0.0725 - val_Brier score: 0.0158 - val_tp: 76.0000 - val_fp: 790.0000 - val_tn: 44697.0000 - val_fn: 6.0000 - val_accuracy: 0.9825 - val_precision: 0.0878 - val_recall: 0.9268 - val_auc: 0.9766 - val_prc: 0.7553 Epoch 4/100 278/278 [==============================] - 6s 21ms/step - loss: 0.1151 - cross entropy: 0.1151 - Brier score: 0.0341 - tp: 269190.0000 - fp: 9719.0000 - tn: 274441.0000 - fn: 15994.0000 - accuracy: 0.9548 - precision: 0.9652 - recall: 0.9439 - auc: 0.9925 - prc: 0.9930 - val_loss: 0.0596 - val_cross entropy: 0.0596 - val_Brier score: 0.0136 - val_tp: 76.0000 - val_fp: 725.0000 - 
val_tn: 44762.0000 - val_fn: 6.0000 - val_accuracy: 0.9840 - val_precision: 0.0949 - val_recall: 0.9268 - val_auc: 0.9726 - val_prc: 0.7292 Epoch 5/100 278/278 [==============================] - 6s 20ms/step - loss: 0.1006 - cross entropy: 0.1006 - Brier score: 0.0299 - tp: 270949.0000 - fp: 8853.0000 - tn: 275916.0000 - fn: 13626.0000 - accuracy: 0.9605 - precision: 0.9684 - recall: 0.9521 - auc: 0.9945 - prc: 0.9946 - val_loss: 0.0525 - val_cross entropy: 0.0525 - val_Brier score: 0.0124 - val_tp: 76.0000 - val_fp: 668.0000 - val_tn: 44819.0000 - val_fn: 6.0000 - val_accuracy: 0.9852 - val_precision: 0.1022 - val_recall: 0.9268 - val_auc: 0.9717 - val_prc: 0.7216 Epoch 6/100 278/278 [==============================] - 6s 20ms/step - loss: 0.0904 - cross entropy: 0.0904 - Brier score: 0.0268 - tp: 272681.0000 - fp: 8122.0000 - tn: 276344.0000 - fn: 12197.0000 - accuracy: 0.9643 - precision: 0.9711 - recall: 0.9572 - auc: 0.9958 - prc: 0.9956 - val_loss: 0.0456 - val_cross entropy: 0.0456 - val_Brier score: 0.0108 - val_tp: 76.0000 - val_fp: 576.0000 - val_tn: 44911.0000 - val_fn: 6.0000 - val_accuracy: 0.9872 - val_precision: 0.1166 - val_recall: 0.9268 - val_auc: 0.9737 - val_prc: 0.7304 Epoch 7/100 278/278 [==============================] - 6s 20ms/step - loss: 0.0828 - cross entropy: 0.0828 - Brier score: 0.0244 - tp: 273911.0000 - fp: 7426.0000 - tn: 277008.0000 - fn: 10999.0000 - accuracy: 0.9676 - precision: 0.9736 - recall: 0.9614 - auc: 0.9965 - prc: 0.9963 - val_loss: 0.0408 - val_cross entropy: 0.0408 - val_Brier score: 0.0099 - val_tp: 77.0000 - val_fp: 546.0000 - val_tn: 44941.0000 - val_fn: 5.0000 - val_accuracy: 0.9879 - val_precision: 0.1236 - val_recall: 0.9390 - val_auc: 0.9752 - val_prc: 0.7232 Epoch 8/100 278/278 [==============================] - 5s 20ms/step - loss: 0.0775 - cross entropy: 0.0775 - Brier score: 0.0228 - tp: 274985.0000 - fp: 6904.0000 - tn: 277146.0000 - fn: 10309.0000 - accuracy: 0.9698 - precision: 0.9755 - recall: 0.9639 - 
auc: 0.9970 - prc: 0.9968 - val_loss: 0.0387 - val_cross entropy: 0.0387 - val_Brier score: 0.0096 - val_tp: 77.0000 - val_fp: 568.0000 - val_tn: 44919.0000 - val_fn: 5.0000 - val_accuracy: 0.9874 - val_precision: 0.1194 - val_recall: 0.9390 - val_auc: 0.9761 - val_prc: 0.7145 Epoch 9/100 278/278 [==============================] - 5s 20ms/step - loss: 0.0743 - cross entropy: 0.0743 - Brier score: 0.0219 - tp: 274086.0000 - fp: 6704.0000 - tn: 278828.0000 - fn: 9726.0000 - accuracy: 0.9711 - precision: 0.9761 - recall: 0.9657 - auc: 0.9971 - prc: 0.9969 - val_loss: 0.0344 - val_cross entropy: 0.0344 - val_Brier score: 0.0085 - val_tp: 76.0000 - val_fp: 492.0000 - val_tn: 44995.0000 - val_fn: 6.0000 - val_accuracy: 0.9891 - val_precision: 0.1338 - val_recall: 0.9268 - val_auc: 0.9767 - val_prc: 0.7147 Epoch 10/100 278/278 [==============================] - 5s 20ms/step - loss: 0.0712 - cross entropy: 0.0712 - Brier score: 0.0211 - tp: 275221.0000 - fp: 6399.0000 - tn: 278199.0000 - fn: 9525.0000 - accuracy: 0.9720 - precision: 0.9773 - recall: 0.9665 - auc: 0.9973 - prc: 0.9970 - val_loss: 0.0311 - val_cross entropy: 0.0311 - val_Brier score: 0.0077 - val_tp: 76.0000 - val_fp: 434.0000 - val_tn: 45053.0000 - val_fn: 6.0000 - val_accuracy: 0.9903 - val_precision: 0.1490 - val_recall: 0.9268 - val_auc: 0.9772 - val_prc: 0.7140 Epoch 11/100 276/278 [============================>.] - ETA: 0s - loss: 0.0695 - cross entropy: 0.0695 - Brier score: 0.0206 - tp: 273841.0000 - fp: 6329.0000 - tn: 275888.0000 - fn: 9190.0000 - accuracy: 0.9725 - precision: 0.9774 - recall: 0.9675 - auc: 0.9973 - prc: 0.9970Restoring model weights from the end of the best epoch: 1. 
278/278 [==============================] - 5s 20ms/step - loss: 0.0695 - cross entropy: 0.0695 - Brier score: 0.0206 - tp: 275842.0000 - fp: 6384.0000 - tn: 277849.0000 - fn: 9269.0000 - accuracy: 0.9725 - precision: 0.9774 - recall: 0.9675 - auc: 0.9973 - prc: 0.9970 - val_loss: 0.0302 - val_cross entropy: 0.0302 - val_Brier score: 0.0075 - val_tp: 76.0000 - val_fp: 433.0000 - val_tn: 45054.0000 - val_fn: 6.0000 - val_accuracy: 0.9904 - val_precision: 0.1493 - val_recall: 0.9268 - val_auc: 0.9775 - val_prc: 0.7154 Epoch 11: early stopping
If the training process were considering the whole dataset on each gradient update, this oversampling would be basically identical to the class weighting.
But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: instead of each positive example being shown in one batch with a large weight, they're shown in many different batches each time with a small weight.
This smoother gradient signal makes it easier to train the model.
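The difference in gradient smoothness can be illustrated with a rough simulation (this is not the tutorial's code, and the batch statistics are modeled as simple binomial draws). Under class weighting, a batch drawn from the raw distribution contains a handful of heavily weighted positives whose count varies a lot from batch to batch; under 50/50 oversampling the positive count per batch is large and nearly constant:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n_batches = 2048, 1000
pos_rate = 492 / 284807  # positive rate of the raw data

# Class weighting: batches drawn from the raw distribution.
# Oversampling: batches drawn 50/50 from positives and negatives.
weighted = rng.binomial(batch_size, pos_rate, n_batches)
resampled = rng.binomial(batch_size, 0.5, n_batches)

def cv(x):
  # Coefficient of variation: relative spread of the positive count,
  # a proxy for the batch-to-batch variability of the positive-class
  # gradient contribution.
  return x.std() / x.mean()

print('relative spread, weighted batches:  {:.2f}'.format(cv(weighted)))
print('relative spread, resampled batches: {:.2f}'.format(cv(resampled)))
```

The weighted batches show an order of magnitude more relative variability in their positive count, which is the "noisier gradient" the text describes.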
Check training history
Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data.
plot_metrics(resampled_history)
Re-train
Because training is easier on the balanced data, the above training procedure may overfit quickly.
So break up the epochs to give tf.keras.callbacks.EarlyStopping finer control over when to stop training.
resampled_model = make_model()
resampled_model.load_weights(initial_weights)
# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1]
output_layer.bias.assign([0])
resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch=20,
    epochs=10 * EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_ds))
Epoch 1/1000 20/20 [==============================] - 2s 47ms/step - loss: 0.6826 - cross entropy: 0.3390 - Brier score: 0.1176 - tp: 18430.0000 - fp: 14493.0000 - tn: 51299.0000 - fn: 2307.0000 - accuracy: 0.8058 - precision: 0.5598 - recall: 0.8887 - auc: 0.9464 - prc: 0.8794 - val_loss: 1.0004 - val_cross entropy: 1.0004 - val_Brier score: 0.3825 - val_tp: 79.0000 - val_fp: 36434.0000 - val_tn: 9053.0000 - val_fn: 3.0000 - val_accuracy: 0.2004 - val_precision: 0.0022 - val_recall: 0.9634 - val_auc: 0.9261 - val_prc: 0.5749 Epoch 2/1000 20/20 [==============================] - 0s 24ms/step - loss: 0.5885 - cross entropy: 0.5885 - Brier score: 0.2089 - tp: 18504.0000 - fp: 12626.0000 - tn: 7862.0000 - fn: 1968.0000 - accuracy: 0.6437 - precision: 0.5944 - recall: 0.9039 - auc: 0.8752 - prc: 0.9121 - val_loss: 0.8348 - val_cross entropy: 0.8348 - val_Brier score: 0.3117 - val_tp: 79.0000 - val_fp: 28884.0000 - val_tn: 16603.0000 - val_fn: 3.0000 - val_accuracy: 0.3661 - val_precision: 0.0027 - val_recall: 0.9634 - val_auc: 0.9395 - val_prc: 0.6709 Epoch 3/1000 20/20 [==============================] - 1s 27ms/step - loss: 0.5070 - cross entropy: 0.5070 - Brier score: 0.1790 - tp: 18418.0000 - fp: 10425.0000 - tn: 10193.0000 - fn: 1924.0000 - accuracy: 0.6985 - precision: 0.6386 - recall: 0.9054 - auc: 0.8991 - prc: 0.9280 - val_loss: 0.6975 - val_cross entropy: 0.6975 - val_Brier score: 0.2495 - val_tp: 78.0000 - val_fp: 19535.0000 - val_tn: 25952.0000 - val_fn: 4.0000 - val_accuracy: 0.5712 - val_precision: 0.0040 - val_recall: 0.9512 - val_auc: 0.9499 - val_prc: 0.7048 Epoch 4/1000 20/20 [==============================] - 0s 25ms/step - loss: 0.4413 - cross entropy: 0.4413 - Brier score: 0.1530 - tp: 18483.0000 - fp: 8228.0000 - tn: 12349.0000 - fn: 1900.0000 - accuracy: 0.7527 - precision: 0.6920 - recall: 0.9068 - auc: 0.9179 - prc: 0.9406 - val_loss: 0.5893 - val_cross entropy: 0.5893 - val_Brier score: 0.1998 - val_tp: 77.0000 - val_fp: 11782.0000 - val_tn: 
33705.0000 - val_fn: 5.0000 - val_accuracy: 0.7413 - val_precision: 0.0065 - val_recall: 0.9390 - val_auc: 0.9552 - val_prc: 0.7246 Epoch 5/1000 20/20 [==============================] - 0s 25ms/step - loss: 0.3914 - cross entropy: 0.3914 - Brier score: 0.1335 - tp: 18615.0000 - fp: 6548.0000 - tn: 13896.0000 - fn: 1901.0000 - accuracy: 0.7937 - precision: 0.7398 - recall: 0.9073 - auc: 0.9304 - prc: 0.9500 - val_loss: 0.5045 - val_cross entropy: 0.5045 - val_Brier score: 0.1613 - val_tp: 77.0000 - val_fp: 7135.0000 - val_tn: 38352.0000 - val_fn: 5.0000 - val_accuracy: 0.8433 - val_precision: 0.0107 - val_recall: 0.9390 - val_auc: 0.9595 - val_prc: 0.7424 Epoch 6/1000 20/20 [==============================] - 0s 26ms/step - loss: 0.3563 - cross entropy: 0.3563 - Brier score: 0.1183 - tp: 18429.0000 - fp: 5050.0000 - tn: 15533.0000 - fn: 1948.0000 - accuracy: 0.8292 - precision: 0.7849 - recall: 0.9044 - auc: 0.9391 - prc: 0.9552 - val_loss: 0.4395 - val_cross entropy: 0.4395 - val_Brier score: 0.1328 - val_tp: 77.0000 - val_fp: 4727.0000 - val_tn: 40760.0000 - val_fn: 5.0000 - val_accuracy: 0.8962 - val_precision: 0.0160 - val_recall: 0.9390 - val_auc: 0.9616 - val_prc: 0.7625 Epoch 7/1000 20/20 [==============================] - 0s 25ms/step - loss: 0.3220 - cross entropy: 0.3220 - Brier score: 0.1047 - tp: 18807.0000 - fp: 4065.0000 - tn: 16241.0000 - fn: 1847.0000 - accuracy: 0.8557 - precision: 0.8223 - recall: 0.9106 - auc: 0.9485 - prc: 0.9631 - val_loss: 0.3867 - val_cross entropy: 0.3867 - val_Brier score: 0.1105 - val_tp: 77.0000 - val_fp: 3192.0000 - val_tn: 42295.0000 - val_fn: 5.0000 - val_accuracy: 0.9298 - val_precision: 0.0236 - val_recall: 0.9390 - val_auc: 0.9635 - val_prc: 0.7711 Epoch 8/1000 20/20 [==============================] - 0s 25ms/step - loss: 0.3012 - cross entropy: 0.3012 - Brier score: 0.0959 - tp: 18607.0000 - fp: 3384.0000 - tn: 17165.0000 - fn: 1804.0000 - accuracy: 0.8733 - precision: 0.8461 - recall: 0.9116 - auc: 0.9545 - prc: 
0.9661 - val_loss: 0.3438 - val_cross entropy: 0.3438 - val_Brier score: 0.0932 - val_tp: 77.0000 - val_fp: 2361.0000 - val_tn: 43126.0000 - val_fn: 5.0000 - val_accuracy: 0.9481 - val_precision: 0.0316 - val_recall: 0.9390 - val_auc: 0.9644 - val_prc: 0.7748 Epoch 9/1000 20/20 [==============================] - 0s 25ms/step - loss: 0.2795 - cross entropy: 0.2795 - Brier score: 0.0880 - tp: 18636.0000 - fp: 2891.0000 - tn: 17616.0000 - fn: 1817.0000 - accuracy: 0.8851 - precision: 0.8657 - recall: 0.9112 - auc: 0.9589 - prc: 0.9692 - val_loss: 0.3087 - val_cross entropy: 0.3087 - val_Brier score: 0.0799 - val_tp: 76.0000 - val_fp: 1892.0000 - val_tn: 43595.0000 - val_fn: 6.0000 - val_accuracy: 0.9583 - val_precision: 0.0386 - val_recall: 0.9268 - val_auc: 0.9658 - val_prc: 0.7797 Epoch 10/1000 20/20 [==============================] - 0s 25ms/step - loss: 0.2620 - cross entropy: 0.2620 - Brier score: 0.0812 - tp: 18743.0000 - fp: 2432.0000 - tn: 17955.0000 - fn: 1830.0000 - accuracy: 0.8959 - precision: 0.8851 - recall: 0.9110 - auc: 0.9625 - prc: 0.9724 - val_loss: 0.2798 - val_cross entropy: 0.2798 - val_Brier score: 0.0695 - val_tp: 76.0000 - val_fp: 1615.0000 - val_tn: 43872.0000 - val_fn: 6.0000 - val_accuracy: 0.9644 - val_precision: 0.0449 - val_recall: 0.9268 - val_auc: 0.9674 - val_prc: 0.7834 Epoch 11/1000 20/20 [==============================] - 0s 25ms/step - loss: 0.2480 - cross entropy: 0.2480 - Brier score: 0.0757 - tp: 18645.0000 - fp: 2154.0000 - tn: 18383.0000 - fn: 1778.0000 - accuracy: 0.9040 - precision: 0.8964 - recall: 0.9129 - auc: 0.9668 - prc: 0.9748 - val_loss: 0.2551 - val_cross entropy: 0.2551 - val_Brier score: 0.0611 - val_tp: 76.0000 - val_fp: 1428.0000 - val_tn: 44059.0000 - val_fn: 6.0000 - val_accuracy: 0.9685 - val_precision: 0.0505 - val_recall: 0.9268 - val_auc: 0.9691 - val_prc: 0.7858 Epoch 12/1000 20/20 [==============================] - 0s 25ms/step - loss: 0.2368 - cross entropy: 0.2368 - Brier score: 0.0722 - tp: 
Epoch 13/1000 … Epoch 30/1000: per-epoch log truncated; loss falls steadily from 0.2223 to 0.1465 and val_loss from 0.2341 to 0.0941, while val_recall holds near 0.93 and val_prc peaks at 0.7973 at epoch 21 before plateauing. Restoring model weights from the end of the best epoch: 21. Epoch 31: early stopping
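The log above stops at epoch 31 while restoring the best weights from epoch 21, which is consistent with an early-stopping callback using a patience of 10. A minimal sketch of such a callback; the monitored metric (`val_prc`) and patience value are assumptions inferred from the log, not shown in this cell:

```python
import tensorflow as tf

# Sketch of an early-stopping configuration consistent with the log above:
# watch the validation PR AUC, stop after 10 epochs without improvement,
# and roll the model back to the best weights seen so far.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_prc',
    mode='max',               # val_prc should be maximized
    patience=10,              # 31 - 21 = 10 epochs in the log above
    restore_best_weights=True,
    verbose=1)
```

Passing this callback to `model.fit(..., callbacks=[early_stopping])` lets training run toward the nominal 1000 epochs while stopping as soon as the validation metric stalls.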
Re-check the training history
plot_metrics(resampled_history)
Evaluate metrics
train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step 28/28 [==============================] - 0s 1ms/step
resampled_results = resampled_model.evaluate(test_features, test_labels,
batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
print(name, ': ', value)
print()
plot_cm(test_labels, test_predictions_resampled)
loss :  0.13269135355949402
cross entropy :  0.13269135355949402
Brier score :  0.02699681930243969
tp :  96.0
fp :  1177.0
tn :  55674.0
fn :  15.0
accuracy :  0.9790737628936768
precision :  0.07541241496801376
recall :  0.8648648858070374
auc :  0.9722627401351929
prc :  0.703483521938324

Legitimate Transactions Detected (True Negatives): 55674
Legitimate Transactions Incorrectly Detected (False Positives): 1177
Fraudulent Transactions Missed (False Negatives): 15
Fraudulent Transactions Detected (True Positives): 96
Total Fraudulent Transactions: 111
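The counts printed above come from `plot_cm`, which is defined earlier in the tutorial and thresholds the predicted probabilities before tallying outcomes. A minimal sketch of the underlying computation, using sklearn's `confusion_matrix` and an assumed default threshold of 0.5 (the helper name `cm_report` is illustrative only):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

def cm_report(labels, predictions, p=0.5):
    """Threshold probabilities at p, then tally the 2x2 confusion matrix."""
    cm = confusion_matrix(labels, np.asarray(predictions) > p)
    tn, fp, fn, tp = cm.ravel()
    print('Legitimate Transactions Detected (True Negatives):', tn)
    print('Legitimate Transactions Incorrectly Detected (False Positives):', fp)
    print('Fraudulent Transactions Missed (False Negatives):', fn)
    print('Fraudulent Transactions Detected (True Positives):', tp)
    print('Total Fraudulent Transactions:', fn + tp)
    return cm
```

Lowering `p` trades false negatives for false positives, which is exactly the knob to turn when missing a fraudulent transaction costs more than flagging a legitimate one.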
Plot the ROC
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right');
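The `plot_roc` helper used above is defined earlier in the tutorial; a minimal sketch of an equivalent function, assuming sklearn's `roc_curve` for the underlying computation (the name `roc_sketch` is illustrative):

```python
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve

def roc_sketch(name, labels, predictions, **kwargs):
    """Plot a ROC curve: false-positive rate vs. true-positive rate."""
    fp_rate, tp_rate, _ = roc_curve(labels, predictions)
    plt.plot(100 * fp_rate, 100 * tp_rate, label=name, linewidth=2, **kwargs)
    plt.xlabel('False positives [%]')
    plt.ylabel('True positives [%]')
    plt.grid(True)
```

Because the ROC sweeps over every possible threshold, it compares the baseline, weighted, and resampled models independently of any single operating point.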
Plot the AUPRC
plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plot_prc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_prc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right');
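Similarly, the `plot_prc` helper is defined earlier in the tutorial; a minimal sketch of a comparable precision-recall plotting function, assuming sklearn's `precision_recall_curve` (the name `prc_sketch` is illustrative):

```python
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve

def prc_sketch(name, labels, predictions, **kwargs):
    """Plot precision against recall over all probability thresholds."""
    precision, recall, _ = precision_recall_curve(labels, predictions)
    plt.plot(recall, precision, label=name, linewidth=2, **kwargs)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.grid(True)
```

On heavily imbalanced data like this one, the precision-recall curve is usually more informative than the ROC, since it is not dominated by the huge number of true negatives.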
Applying this tutorial to your problem
Imbalanced data classification is an inherently difficult task because there are so few examples to learn from. You should always start with the data: collect as many samples as possible, and think carefully about which features may be relevant so the model can make the most of your minority class. At some point your model may struggle to improve and still fall short of the results you want, so it is important to keep the context of the problem in mind, along with the trade-offs between different kinds of errors.