Classification on imbalanced data

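This tutorial demonstrates how to classify a highly imbalanced dataset, in which the examples of one class greatly outnumber the examples of the other. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions out of 284,807 transactions in total. You will use Keras to define the model and class weights to help the model learn from the imbalanced data.

This tutorial contains complete code to:

  • Load a CSV file using Pandas.
  • Create train, validation, and test sets.
  • Define and train a model using Keras (including setting class weights).
  • Evaluate the model using various metrics (including precision and recall).
  • Select a threshold for a probabilistic classifier to get a deterministic classifier.
  • Try and compare class-weighted modelling and oversampling.

Setup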

import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

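Data processing and exploration

Download the Kaggle Credit Card Fraud dataset

Pandas is a Python library with many helpful utilities for loading and working with structured data. It can be used to download a CSV file straight into a Pandas DataFrame.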

# Read the CSV straight from the hosted URL into a DataFrame.
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()
raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()

Examine the class label imbalance

Let's look at the dataset imbalance:

neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)

This shows the small fraction of positive samples.
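A quick worked example shows why this imbalance matters for evaluation: a classifier that never predicts fraud is already right almost all the time, so plain accuracy says little here. The sketch below only reuses the two counts printed above; the variable names are made up for illustration.

```python
# Counts taken from the output above.
total = 284807
pos = 492
neg = total - pos

# A hypothetical classifier that labels every transaction as legitimate
# is correct on all negatives and wrong on all positives.
always_negative_accuracy = neg / total
always_negative_recall = 0 / pos  # it never finds a single fraud

print(f'Accuracy of always predicting "no fraud": {always_negative_accuracy:.4%}')
print(f'Recall of always predicting "no fraud":   {always_negative_recall:.4%}')
```

Any useful model therefore has to be judged on more than accuracy, which is why this tutorial tracks precision, recall, and related metrics.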

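Clean, split and normalize the data

The raw data has a few issues. First, the `Time` and `Amount` columns are too variable to use directly. Drop the `Time` column (since it's not clear what it means) and take the log of the `Amount` column to reduce its range.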

cleaned_df = raw_df.copy()

# You don't want the `Time` column.
cleaned_df.pop('Time')

# The `Amount` column covers a huge range. Convert to log-space.
eps = 0.001 # 0 => 0.1¢
cleaned_df['Log Amount'] = np.log(cleaned_df.pop('Amount')+eps)
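To see what the log transform buys you, here is a minimal sketch on a handful of made-up amounts (the `amounts` array is hypothetical, not taken from the dataset). It uses the same `eps` offset as above, which keeps a zero amount from producing `log(0)`.

```python
import numpy as np

# Hypothetical transaction amounts in dollars, including a zero amount.
amounts = np.array([0.0, 1.0, 10.0, 100.0, 25000.0])

eps = 0.001  # same offset as above, so log(0) is never evaluated
log_amounts = np.log(amounts + eps)

print('raw range:', amounts.max() - amounts.min())
print('log range:', log_amounts.max() - log_amounts.min())
print(log_amounts.round(3))
```

A spread of several orders of magnitude in the raw values collapses to a span of a few tens in log-space, which is far easier to standardize afterwards.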

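Split the dataset into train, validation, and test sets. The validation set is used during model fitting to evaluate the loss and any metrics, but the model is not fit on this data. The test set is completely unused during the training phase and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets, where overfitting is a significant concern because there are so few training examples of the rare class.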

# Use a utility from sklearn to split and shuffle your dataset.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Form np arrays of labels and features.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)

Check whether the class distribution is roughly the same in the three sets.

print(f'Average class probability in training set:   {train_labels.mean():.4f}')
print(f'Average class probability in validation set: {val_labels.mean():.4f}')
print(f'Average class probability in test set:       {test_labels.mean():.4f}')
Average class probability in training set:   0.0016
Average class probability in validation set: 0.0018
Average class probability in test set:       0.0019

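Given the small number of positive labels, this seems about right.

Normalize the input features using the sklearn StandardScaler. This sets the mean to 0 and the standard deviation to 1. Note that the scaler is fit on the training features only, so no statistics from the validation or test sets reach the model.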
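As a rough sanity check on "about right", you can estimate how much the positive fraction is expected to wobble from split to split. The sketch below uses a binomial approximation (an assumption: a random split actually samples without replacement, so this slightly overstates the spread) together with the split sizes printed further down in this tutorial.

```python
import math

p = 492 / 284807  # overall positive fraction from the full dataset

# Split sizes as reported by this tutorial's shape printout.
split_sizes = {'train': 182276, 'val': 45569, 'test': 56962}

for name, n in split_sizes.items():
    # Binomial approximation: standard error of a sample proportion.
    std_err = math.sqrt(p * (1 - p) / n)
    print(f'{name:5s} n={n:6d}  expected {p:.4f} +/- {std_err:.4f}')
```

Differences of one or two ten-thousandths between the splits are within roughly one standard error, so the observed values are consistent with a purely random split.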

scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)
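The fit-on-train, transform-everything pattern used above can be sketched with plain NumPy. This is an illustration on made-up data (the `toy_*` arrays are hypothetical), not a replacement for `StandardScaler`; it only mirrors the arithmetic of standardizing and then clipping to [-5, 5].

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up data standing in for the real features: 2 columns, different scales.
toy_train = rng.normal(loc=[10.0, -3.0], scale=[2.0, 50.0], size=(1000, 2))
toy_val = rng.normal(loc=[10.0, -3.0], scale=[2.0, 50.0], size=(200, 2))

# "Fit": statistics come from the training split only.
mean = toy_train.mean(axis=0)
std = toy_train.std(axis=0)

# "Transform": the same statistics are applied to every split, then clipped.
toy_train_scaled = np.clip((toy_train - mean) / std, -5, 5)
toy_val_scaled = np.clip((toy_val - mean) / std, -5, 5)

print('train mean:', toy_train_scaled.mean(axis=0).round(3))
print('train std: ', toy_train_scaled.std(axis=0).round(3))
print('val mean:  ', toy_val_scaled.mean(axis=0).round(3))
```

The training split ends up with mean 0 and standard deviation 1 by construction, while the validation split is only approximately standardized, which is expected because its own statistics were never used.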

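Look at the data distribution

Next, compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:

  • Do these distributions make sense?
    • Yes. You've normalized the input, and the values are mostly concentrated in the +/- 2 range.
  • Can you see the difference between the distributions?
    • Yes. The positive examples contain a much higher rate of extreme values.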
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)

sns.jointplot(x=pos_df['V5'], y=pos_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")

sns.jointplot(x=neg_df['V5'], y=neg_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")


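Define the model and metrics

Define a function that creates a simple neural network with a densely connected hidden layer, a dropout layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent: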

METRICS = [
      keras.metrics.BinaryCrossentropy(name='cross entropy'),  # same as model's loss
      keras.metrics.MeanSquaredError(name='Brier score'),
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
      keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]

def make_model(metrics=METRICS, output_bias=None):
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(
          16, activation='relu',
          input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])

  model.compile(
      optimizer=keras.optimizers.Adam(learning_rate=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)

  return model

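Understanding useful metrics

Notice that a few metrics are defined above that the model can compute and that will be helpful when evaluating performance. These can be divided into three groups.

Metrics for probability predictions

Because the network is trained with cross entropy as its loss function, it is fully capable of predicting class probabilities, i.e., it is a probabilistic classifier. Good metrics for assessing probabilistic predictions are, in fact, proper scoring rules. Their key property is that predicting the true probability is optimal. Two well-known examples are:

  • cross entropy, also known as log loss
  • mean squared error, also known as the Brier score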
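The "predicting the true probability is optimal" property can be checked numerically. The sketch below assumes a made-up true positive rate `true_p` and scans constant predictions `q`; for each `q` it computes the expected log loss and expected Brier score and reports which `q` minimizes each.

```python
import numpy as np

true_p = 0.2  # assumed true probability of the positive class
candidates = np.linspace(0.01, 0.99, 99)  # candidate predictions q

# Expected score when the label is 1 with probability true_p and the
# model always predicts q. Lower is better for both scores.
expected_log_loss = -(true_p * np.log(candidates)
                      + (1 - true_p) * np.log(1 - candidates))
expected_brier = (true_p * (1 - candidates) ** 2
                  + (1 - true_p) * candidates ** 2)

best_q_log_loss = candidates[np.argmin(expected_log_loss)]
best_q_brier = candidates[np.argmin(expected_brier)]
print('best q under log loss:   ', round(best_q_log_loss, 2))
print('best q under Brier score:', round(best_q_brier, 2))
```

Both scores are minimized when the prediction equals `true_p`, which is what makes them suitable for judging probability estimates rather than hard labels.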

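Metrics for deterministic 0/1 predictions

In the end, one often wants to predict a class label, 0 or 1, no fraud or fraud. This is called a deterministic classifier. To get a label prediction from the probabilistic classifier, you need to choose a probability threshold \(t\). The default is to predict label 1 (fraud) if the predicted probability is larger than \(t=50\%\), and all of the following metrics implicitly use this default.

  • False negatives and false positives are samples that were incorrectly classified.
  • True negatives and true positives are samples that were correctly classified.
  • Accuracy is the percentage of examples correctly classified: \(\frac{\text{true samples} }{\text{total samples} }\)
  • Precision is the percentage of predicted positives that were correctly classified: \(\frac{\text{true positives} }{\text{true positives + false positives} }\)
  • Recall is the percentage of actual positives that were correctly classified: \(\frac{\text{true positives} }{\text{true positives + false negatives} }\)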
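The definitions above can be traced by hand on a tiny example. The probabilities and labels below are hypothetical; the sketch thresholds them at \(t=0.5\), counts the four outcomes, and derives accuracy, precision, and recall from those counts.

```python
import numpy as np

# Hypothetical predicted probabilities and true labels for 8 transactions.
probs = np.array([0.02, 0.10, 0.95, 0.40, 0.70, 0.03, 0.60, 0.01])
labels = np.array([0, 0, 1, 1, 1, 0, 0, 0])

t = 0.5  # default threshold
preds = (probs > t).astype(int)

tp = int(np.sum((preds == 1) & (labels == 1)))
fp = int(np.sum((preds == 1) & (labels == 0)))
tn = int(np.sum((preds == 0) & (labels == 0)))
fn = int(np.sum((preds == 0) & (labels == 1)))

accuracy = (tp + tn) / (tp + fp + tn + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
print(f'tp={tp} fp={fp} tn={tn} fn={fn}')
print(f'accuracy={accuracy:.3f} precision={precision:.3f} recall={recall:.3f}')
```

Lowering `t` in this sketch turns the missed fraud (probability 0.40) into a true positive, raising recall, which illustrates the trade-off that the threshold controls.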

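Other metrics

The following metrics take into account all possible choices of the threshold \(t\).

  • AUC refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.
  • AUPRC refers to the Area Under the Curve of the Precision-Recall Curve. This metric computes precision-recall pairs for different probability thresholds.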
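The ranking interpretation of ROC-AUC can be computed directly by comparing every positive score with every negative score. The scores below are hypothetical, and counting ties as half a win is a common convention assumed here.

```python
import numpy as np

# Hypothetical scores; higher means "more likely fraud".
pos_scores = np.array([0.9, 0.7, 0.4])
neg_scores = np.array([0.6, 0.3, 0.2, 0.1])

# Compare every positive with every negative (ties count as half).
diff = pos_scores[:, None] - neg_scores[None, :]
pairwise_auc = np.mean((diff > 0) + 0.5 * (diff == 0))
print('ROC-AUC from pairwise ranking:', pairwise_auc)
```

Here 11 of the 12 positive-negative pairs are ordered correctly. `keras.metrics.AUC` approximates the area from a fixed set of thresholds, so its value on real data need not match an exact pairwise count.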

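Baseline model

Build the model

Now create and train your model using the function that was defined earlier. Notice that the model is fit using a larger-than-default batch size of 2048. This is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size were too small, many batches would likely contain no fraudulent transactions to learn from.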
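The effect of batch size can be quantified with a short calculation. The sketch below assumes that examples land in a batch independently with the dataset-wide fraud rate, which only approximates shuffled mini-batches drawn from a finite dataset.

```python
p = 492 / 284807  # fraction of fraudulent transactions in the full dataset

for batch_size in (32, 256, 2048):
    # Assumes examples are drawn independently, which only approximates
    # shuffled mini-batches taken from a finite dataset.
    p_no_positive = (1 - p) ** batch_size
    expected_positives = p * batch_size
    print(f'batch_size={batch_size:5d}  '
          f'P(no fraud in batch)={p_no_positive:.3f}  '
          f'expected frauds per batch={expected_positives:.2f}')
```

With the Keras default of 32, the vast majority of batches would contain no fraud at all, whereas at 2048 only a few percent of batches are empty of positives.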

EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_prc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
model = make_model()
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 16)                480       
                                                                 
 dropout (Dropout)           (None, 16)                0         
                                                                 
 dense_1 (Dense)             (None, 1)                 17        
                                                                 
=================================================================
Total params: 497 (1.94 KB)
Trainable params: 497 (1.94 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
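The parameter counts in the summary can be reproduced by hand, which is a useful check that the input dimension is what you expect. This is plain arithmetic based on the 29 feature columns and the 16 hidden units used above.

```python
n_features = 29   # feature columns after dropping `Time` and the label
hidden_units = 16

# A Dense layer has one weight per (input, unit) pair plus one bias per unit.
hidden_params = n_features * hidden_units + hidden_units
output_params = hidden_units * 1 + 1
dropout_params = 0  # Dropout has no trainable weights

total_params = hidden_params + dropout_params + output_params
print(hidden_params, output_params, total_params)
```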

Test run the model:

model.predict(train_features[:10])
1/1 [==============================] - 0s 471ms/step
array([[0.16263928],
       [0.35204744],
       [0.19377157],
       [0.72603256],
       [0.30116165],
       [0.25605297],
       [0.66053736],
       [0.31973222],
       [0.25077152],
       [0.26151225]], dtype=float32)

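Optional: Set the correct initial bias.

These initial guesses are not great. You know the dataset is imbalanced. Set the output layer's bias to reflect that; see A Recipe for Training Neural Networks: "init well". This can help with initial convergence.

With the default bias initialization, the loss should be about math.log(2) = 0.69314: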

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.4088

The correct bias to set can be derived from:

\[ p_0 = pos/(pos + neg) = 1/(1+e^{-b_0}) \]

\[ b_0 = -\log_e(1/p_0 - 1) \]

\[ b_0 = \log_e(pos/neg) \]

initial_bias = np.log([pos/neg])
initial_bias
array([-6.35935934])
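You can confirm the derivation by pushing the bias back through the sigmoid. This sketch uses only the class counts of the full dataset and the standard library; it checks that \(b_0\) alone reproduces the positive fraction.

```python
import math

pos, neg = 492, 284315  # class counts from the full dataset

b_0 = math.log(pos / neg)
p_0 = 1 / (1 + math.exp(-b_0))  # sigmoid of the bias alone

print(f'b_0 = {b_0:.4f}')
print(f'sigmoid(b_0) = {p_0:.6f}')
print(f'pos / (pos + neg) = {pos / (pos + neg):.6f}')
```

In the real model the hidden layer also contributes to the output, so individual predictions scatter around this value rather than matching it exactly.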

Set that as the initial bias, and the model will give much more reasonable initial guesses.

They should be near: pos/total = 0.0017

model = make_model(output_bias=initial_bias)
model.predict(train_features[:10])
1/1 [==============================] - 0s 75ms/step
array([[0.00135984],
       [0.00134607],
       [0.00213977],
       [0.01406598],
       [0.0021732 ],
       [0.00640495],
       [0.00814889],
       [0.00254694],
       [0.00572464],
       [0.00216844]], dtype=float32)

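With this initialization, the initial loss should be approximately the following, using \(p_0 = pos/total \approx 0.0017\):

\[-p_0\log(p_0)-(1-p_0)\log(1-p_0) \approx 0.0127\]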
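The expression above is easy to evaluate directly. With the unrounded \(p_0 = 492/284807\) it comes out at roughly 0.0127; rounding \(p_0\) to 0.0018 first gives about 0.0132. The sketch also computes the loss of a constant 0.5 prediction for comparison.

```python
import math

pos, total = 492, 284807
p_0 = pos / total

# Cross-entropy of a model that predicts p_0 for every example.
careful_loss = -(p_0 * math.log(p_0) + (1 - p_0) * math.log(1 - p_0))

# Cross-entropy of a model that predicts 0.5 for every example.
naive_loss = math.log(2)

print(f'constant p_0 prediction: {careful_loss:.5f}')
print(f'constant 0.5 prediction: {naive_loss:.5f}')
```

These are the values for constant predictions only. The measured losses differ somewhat because the untrained hidden layer already makes the predictions vary from example to example.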

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))
Loss: 0.0087

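This initial loss is about 50 times smaller than it would have been with naive initialization.

This way the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. It also makes plots of the loss during training easier to read.

Checkpoint the initial weights

To make the various training runs more comparable, keep this initial model's weights in a checkpoint file, and load them into each model before training: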

initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')
model.save_weights(initial_weights)

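Confirm that the bias fix helps

Before moving on, quickly confirm that the careful bias initialization actually helped.

Train the model for 20 epochs, with and without this careful initialization, and compare the losses: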

model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
def plot_loss(history, label, n):
  # Use a log scale on y-axis to show the wide range of values.
  plt.semilogy(history.epoch, history.history['loss'],
               color=colors[n], label='Train ' + label)
  plt.semilogy(history.epoch, history.history['val_loss'],
               color=colors[n], label='Val ' + label,
               linestyle="--")
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.legend()
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)

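The above figure makes it clear: in terms of validation loss, this careful initialization gives a clear advantage on this problem.

Train the model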

model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels))
Epoch 1/100
90/90 [==============================] - 2s 11ms/step - loss: 0.0109 - cross entropy: 0.0092 - Brier score: 0.0013 - tp: 179.0000 - fp: 128.0000 - tn: 227336.0000 - fn: 202.0000 - accuracy: 0.9986 - precision: 0.5831 - recall: 0.4698 - auc: 0.8759 - prc: 0.4240 - val_loss: 0.0053 - val_cross entropy: 0.0053 - val_Brier score: 7.6563e-04 - val_tp: 44.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 38.0000 - val_accuracy: 0.9991 - val_precision: 0.8980 - val_recall: 0.5366 - val_auc: 0.9188 - val_prc: 0.7535
Epoch 2/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0070 - cross entropy: 0.0070 - Brier score: 9.7767e-04 - tp: 137.0000 - fp: 31.0000 - tn: 181946.0000 - fn: 162.0000 - accuracy: 0.9989 - precision: 0.8155 - recall: 0.4582 - auc: 0.8800 - prc: 0.5341 - val_loss: 0.0044 - val_cross entropy: 0.0044 - val_Brier score: 6.3545e-04 - val_tp: 54.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 28.0000 - val_accuracy: 0.9992 - val_precision: 0.8852 - val_recall: 0.6585 - val_auc: 0.9263 - val_prc: 0.7737
Epoch 3/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0061 - cross entropy: 0.0061 - Brier score: 8.9642e-04 - tp: 146.0000 - fp: 33.0000 - tn: 181944.0000 - fn: 153.0000 - accuracy: 0.9990 - precision: 0.8156 - recall: 0.4883 - auc: 0.9033 - prc: 0.5771 - val_loss: 0.0040 - val_cross entropy: 0.0040 - val_Brier score: 6.0828e-04 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 27.0000 - val_accuracy: 0.9993 - val_precision: 0.9016 - val_recall: 0.6707 - val_auc: 0.9266 - val_prc: 0.7869
Epoch 4/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0057 - cross entropy: 0.0057 - Brier score: 8.9241e-04 - tp: 147.0000 - fp: 30.0000 - tn: 181947.0000 - fn: 152.0000 - accuracy: 0.9990 - precision: 0.8305 - recall: 0.4916 - auc: 0.9045 - prc: 0.6121 - val_loss: 0.0037 - val_cross entropy: 0.0037 - val_Brier score: 5.6512e-04 - val_tp: 58.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.8923 - val_recall: 0.7073 - val_auc: 0.9327 - val_prc: 0.7996
Epoch 5/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0050 - cross entropy: 0.0050 - Brier score: 8.0944e-04 - tp: 163.0000 - fp: 31.0000 - tn: 181946.0000 - fn: 136.0000 - accuracy: 0.9991 - precision: 0.8402 - recall: 0.5452 - auc: 0.9091 - prc: 0.6557 - val_loss: 0.0035 - val_cross entropy: 0.0035 - val_Brier score: 5.4862e-04 - val_tp: 58.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.8923 - val_recall: 0.7073 - val_auc: 0.9327 - val_prc: 0.8041
Epoch 6/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0046 - cross entropy: 0.0046 - Brier score: 7.5796e-04 - tp: 168.0000 - fp: 27.0000 - tn: 181950.0000 - fn: 131.0000 - accuracy: 0.9991 - precision: 0.8615 - recall: 0.5619 - auc: 0.9214 - prc: 0.6995 - val_loss: 0.0034 - val_cross entropy: 0.0034 - val_Brier score: 5.3008e-04 - val_tp: 60.0000 - val_fp: 7.0000 - val_tn: 45480.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.8955 - val_recall: 0.7317 - val_auc: 0.9388 - val_prc: 0.8149
Epoch 7/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0045 - cross entropy: 0.0045 - Brier score: 7.0728e-04 - tp: 183.0000 - fp: 33.0000 - tn: 181944.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8472 - recall: 0.6120 - auc: 0.9133 - prc: 0.6901 - val_loss: 0.0033 - val_cross entropy: 0.0033 - val_Brier score: 5.3596e-04 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.9062 - val_recall: 0.7073 - val_auc: 0.9388 - val_prc: 0.8227
Epoch 8/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0048 - cross entropy: 0.0048 - Brier score: 8.0575e-04 - tp: 169.0000 - fp: 35.0000 - tn: 181942.0000 - fn: 130.0000 - accuracy: 0.9991 - precision: 0.8284 - recall: 0.5652 - auc: 0.9183 - prc: 0.6610 - val_loss: 0.0032 - val_cross entropy: 0.0032 - val_Brier score: 5.4781e-04 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.9062 - val_recall: 0.7073 - val_auc: 0.9389 - val_prc: 0.8321
Epoch 9/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0043 - cross entropy: 0.0043 - Brier score: 7.4602e-04 - tp: 170.0000 - fp: 27.0000 - tn: 181950.0000 - fn: 129.0000 - accuracy: 0.9991 - precision: 0.8629 - recall: 0.5686 - auc: 0.9186 - prc: 0.7075 - val_loss: 0.0031 - val_cross entropy: 0.0031 - val_Brier score: 5.1218e-04 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 24.0000 - val_accuracy: 0.9993 - val_precision: 0.9062 - val_recall: 0.7073 - val_auc: 0.9388 - val_prc: 0.8314
Epoch 10/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0040 - cross entropy: 0.0040 - Brier score: 6.7102e-04 - tp: 178.0000 - fp: 25.0000 - tn: 181952.0000 - fn: 121.0000 - accuracy: 0.9992 - precision: 0.8768 - recall: 0.5953 - auc: 0.9203 - prc: 0.7351 - val_loss: 0.0030 - val_cross entropy: 0.0030 - val_Brier score: 4.8812e-04 - val_tp: 65.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9155 - val_recall: 0.7927 - val_auc: 0.9388 - val_prc: 0.8293
Epoch 11/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0038 - cross entropy: 0.0038 - Brier score: 6.3325e-04 - tp: 191.0000 - fp: 25.0000 - tn: 181952.0000 - fn: 108.0000 - accuracy: 0.9993 - precision: 0.8843 - recall: 0.6388 - auc: 0.9170 - prc: 0.7323 - val_loss: 0.0030 - val_cross entropy: 0.0030 - val_Brier score: 4.8228e-04 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9167 - val_recall: 0.8049 - val_auc: 0.9388 - val_prc: 0.8301
Epoch 12/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0042 - cross entropy: 0.0042 - Brier score: 7.6081e-04 - tp: 173.0000 - fp: 35.0000 - tn: 181942.0000 - fn: 126.0000 - accuracy: 0.9991 - precision: 0.8317 - recall: 0.5786 - auc: 0.9254 - prc: 0.7097 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 4.7943e-04 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9167 - val_recall: 0.8049 - val_auc: 0.9388 - val_prc: 0.8330
Epoch 13/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0043 - cross entropy: 0.0043 - Brier score: 7.4700e-04 - tp: 175.0000 - fp: 29.0000 - tn: 181948.0000 - fn: 124.0000 - accuracy: 0.9992 - precision: 0.8578 - recall: 0.5853 - auc: 0.9238 - prc: 0.6897 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 4.7884e-04 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9167 - val_recall: 0.8049 - val_auc: 0.9388 - val_prc: 0.8350
Epoch 14/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0040 - cross entropy: 0.0040 - Brier score: 7.1931e-04 - tp: 177.0000 - fp: 30.0000 - tn: 181947.0000 - fn: 122.0000 - accuracy: 0.9992 - precision: 0.8551 - recall: 0.5920 - auc: 0.9171 - prc: 0.7144 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 4.8724e-04 - val_tp: 64.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9275 - val_recall: 0.7805 - val_auc: 0.9388 - val_prc: 0.8409
Epoch 15/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0042 - cross entropy: 0.0042 - Brier score: 7.5652e-04 - tp: 167.0000 - fp: 27.0000 - tn: 181950.0000 - fn: 132.0000 - accuracy: 0.9991 - precision: 0.8608 - recall: 0.5585 - auc: 0.9238 - prc: 0.6964 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7200e-04 - val_tp: 66.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9167 - val_recall: 0.8049 - val_auc: 0.9388 - val_prc: 0.8410
Epoch 16/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0039 - cross entropy: 0.0039 - Brier score: 7.1767e-04 - tp: 177.0000 - fp: 34.0000 - tn: 181943.0000 - fn: 122.0000 - accuracy: 0.9991 - precision: 0.8389 - recall: 0.5920 - auc: 0.9239 - prc: 0.7223 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.6891e-04 - val_tp: 65.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9155 - val_recall: 0.7927 - val_auc: 0.9388 - val_prc: 0.8418
Epoch 17/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0041 - cross entropy: 0.0041 - Brier score: 7.5757e-04 - tp: 166.0000 - fp: 27.0000 - tn: 181950.0000 - fn: 133.0000 - accuracy: 0.9991 - precision: 0.8601 - recall: 0.5552 - auc: 0.9255 - prc: 0.7017 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.6881e-04 - val_tp: 64.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9143 - val_recall: 0.7805 - val_auc: 0.9388 - val_prc: 0.8419
Epoch 18/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0038 - cross entropy: 0.0038 - Brier score: 6.7869e-04 - tp: 185.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 114.0000 - accuracy: 0.9992 - precision: 0.8685 - recall: 0.6187 - auc: 0.9289 - prc: 0.7328 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.5812e-04 - val_tp: 67.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 15.0000 - val_accuracy: 0.9995 - val_precision: 0.9178 - val_recall: 0.8171 - val_auc: 0.9449 - val_prc: 0.8473
Epoch 19/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0039 - cross entropy: 0.0039 - Brier score: 6.9306e-04 - tp: 184.0000 - fp: 31.0000 - tn: 181946.0000 - fn: 115.0000 - accuracy: 0.9992 - precision: 0.8558 - recall: 0.6154 - auc: 0.9222 - prc: 0.7129 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7472e-04 - val_tp: 64.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 18.0000 - val_accuracy: 0.9995 - val_precision: 0.9275 - val_recall: 0.7805 - val_auc: 0.9389 - val_prc: 0.8439
Epoch 20/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0037 - cross entropy: 0.0037 - Brier score: 6.5706e-04 - tp: 191.0000 - fp: 31.0000 - tn: 181946.0000 - fn: 108.0000 - accuracy: 0.9992 - precision: 0.8604 - recall: 0.6388 - auc: 0.9240 - prc: 0.7368 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8355e-04 - val_tp: 60.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.9375 - val_recall: 0.7317 - val_auc: 0.9388 - val_prc: 0.8442
Epoch 21/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0039 - cross entropy: 0.0039 - Brier score: 7.2760e-04 - tp: 180.0000 - fp: 33.0000 - tn: 181944.0000 - fn: 119.0000 - accuracy: 0.9992 - precision: 0.8451 - recall: 0.6020 - auc: 0.9223 - prc: 0.7170 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 4.9822e-04 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9365 - val_recall: 0.7195 - val_auc: 0.9388 - val_prc: 0.8441
Epoch 22/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0035 - cross entropy: 0.0035 - Brier score: 6.5225e-04 - tp: 181.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 118.0000 - accuracy: 0.9992 - precision: 0.8660 - recall: 0.6054 - auc: 0.9273 - prc: 0.7439 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8637e-04 - val_tp: 60.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.9375 - val_recall: 0.7317 - val_auc: 0.9388 - val_prc: 0.8438
Epoch 23/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0037 - cross entropy: 0.0037 - Brier score: 6.5348e-04 - tp: 186.0000 - fp: 26.0000 - tn: 181951.0000 - fn: 113.0000 - accuracy: 0.9992 - precision: 0.8774 - recall: 0.6221 - auc: 0.9355 - prc: 0.7402 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.6483e-04 - val_tp: 65.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9155 - val_recall: 0.7927 - val_auc: 0.9388 - val_prc: 0.8427
Epoch 24/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0037 - cross entropy: 0.0037 - Brier score: 6.7939e-04 - tp: 193.0000 - fp: 35.0000 - tn: 181942.0000 - fn: 106.0000 - accuracy: 0.9992 - precision: 0.8465 - recall: 0.6455 - auc: 0.9340 - prc: 0.7279 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 5.1275e-04 - val_tp: 58.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9355 - val_recall: 0.7073 - val_auc: 0.9449 - val_prc: 0.8509
Epoch 25/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.7560e-04 - tp: 180.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 119.0000 - accuracy: 0.9992 - precision: 0.8654 - recall: 0.6020 - auc: 0.9290 - prc: 0.7396 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8990e-04 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9365 - val_recall: 0.7195 - val_auc: 0.9449 - val_prc: 0.8503
Epoch 26/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.4978e-04 - tp: 188.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 111.0000 - accuracy: 0.9992 - precision: 0.8704 - recall: 0.6288 - auc: 0.9307 - prc: 0.7594 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7567e-04 - val_tp: 63.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 19.0000 - val_accuracy: 0.9995 - val_precision: 0.9265 - val_recall: 0.7683 - val_auc: 0.9388 - val_prc: 0.8439
Epoch 27/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0038 - cross entropy: 0.0038 - Brier score: 7.1788e-04 - tp: 183.0000 - fp: 29.0000 - tn: 181948.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8632 - recall: 0.6120 - auc: 0.9289 - prc: 0.7194 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.6391e-04 - val_tp: 65.0000 - val_fp: 6.0000 - val_tn: 45481.0000 - val_fn: 17.0000 - val_accuracy: 0.9995 - val_precision: 0.9155 - val_recall: 0.7927 - val_auc: 0.9388 - val_prc: 0.8454
Epoch 28/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0035 - cross entropy: 0.0035 - Brier score: 6.2664e-04 - tp: 200.0000 - fp: 29.0000 - tn: 181948.0000 - fn: 99.0000 - accuracy: 0.9993 - precision: 0.8734 - recall: 0.6689 - auc: 0.9306 - prc: 0.7426 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 5.1824e-04 - val_tp: 58.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9355 - val_recall: 0.7073 - val_auc: 0.9388 - val_prc: 0.8435
Epoch 29/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0038 - cross entropy: 0.0038 - Brier score: 6.9012e-04 - tp: 185.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 114.0000 - accuracy: 0.9992 - precision: 0.8685 - recall: 0.6187 - auc: 0.9289 - prc: 0.7251 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7472e-04 - val_tp: 60.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 22.0000 - val_accuracy: 0.9994 - val_precision: 0.9231 - val_recall: 0.7317 - val_auc: 0.9449 - val_prc: 0.8510
Epoch 30/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0035 - cross entropy: 0.0035 - Brier score: 6.0173e-04 - tp: 197.0000 - fp: 25.0000 - tn: 181952.0000 - fn: 102.0000 - accuracy: 0.9993 - precision: 0.8874 - recall: 0.6589 - auc: 0.9290 - prc: 0.7333 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8578e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9510 - val_prc: 0.8570
Epoch 31/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.9174e-04 - tp: 187.0000 - fp: 35.0000 - tn: 181942.0000 - fn: 112.0000 - accuracy: 0.9992 - precision: 0.8423 - recall: 0.6254 - auc: 0.9373 - prc: 0.7320 - val_loss: 0.0029 - val_cross entropy: 0.0029 - val_Brier score: 5.2550e-04 - val_tp: 58.0000 - val_fp: 3.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9508 - val_recall: 0.7073 - val_auc: 0.9510 - val_prc: 0.8546
Epoch 32/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0035 - cross entropy: 0.0035 - Brier score: 6.6040e-04 - tp: 183.0000 - fp: 21.0000 - tn: 181956.0000 - fn: 116.0000 - accuracy: 0.9992 - precision: 0.8971 - recall: 0.6120 - auc: 0.9356 - prc: 0.7430 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.9123e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9510 - val_prc: 0.8581
Epoch 33/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.3240e-04 - tp: 198.0000 - fp: 27.0000 - tn: 181950.0000 - fn: 101.0000 - accuracy: 0.9993 - precision: 0.8800 - recall: 0.6622 - auc: 0.9339 - prc: 0.7473 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 5.0220e-04 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9206 - val_recall: 0.7073 - val_auc: 0.9510 - val_prc: 0.8544
Epoch 34/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0034 - cross entropy: 0.0034 - Brier score: 6.3294e-04 - tp: 193.0000 - fp: 29.0000 - tn: 181948.0000 - fn: 106.0000 - accuracy: 0.9993 - precision: 0.8694 - recall: 0.6455 - auc: 0.9373 - prc: 0.7536 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8504e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9449 - val_prc: 0.8489
Epoch 35/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.8906e-04 - tp: 184.0000 - fp: 29.0000 - tn: 181948.0000 - fn: 115.0000 - accuracy: 0.9992 - precision: 0.8638 - recall: 0.6154 - auc: 0.9239 - prc: 0.7403 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.9829e-04 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9206 - val_recall: 0.7073 - val_auc: 0.9449 - val_prc: 0.8482
Epoch 36/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0036 - cross entropy: 0.0036 - Brier score: 6.5897e-04 - tp: 193.0000 - fp: 30.0000 - tn: 181947.0000 - fn: 106.0000 - accuracy: 0.9993 - precision: 0.8655 - recall: 0.6455 - auc: 0.9340 - prc: 0.7307 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.9601e-04 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 24.0000 - val_accuracy: 0.9994 - val_precision: 0.9206 - val_recall: 0.7073 - val_auc: 0.9449 - val_prc: 0.8474
Epoch 37/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0038 - cross entropy: 0.0038 - Brier score: 7.0205e-04 - tp: 184.0000 - fp: 31.0000 - tn: 181946.0000 - fn: 115.0000 - accuracy: 0.9992 - precision: 0.8558 - recall: 0.6154 - auc: 0.9373 - prc: 0.7124 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.9088e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9449 - val_prc: 0.8476
Epoch 38/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0033 - cross entropy: 0.0033 - Brier score: 6.4066e-04 - tp: 195.0000 - fp: 26.0000 - tn: 181951.0000 - fn: 104.0000 - accuracy: 0.9993 - precision: 0.8824 - recall: 0.6522 - auc: 0.9374 - prc: 0.7656 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.8218e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9449 - val_prc: 0.8483
Epoch 39/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0034 - cross entropy: 0.0034 - Brier score: 6.2081e-04 - tp: 195.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 104.0000 - accuracy: 0.9993 - precision: 0.8744 - recall: 0.6522 - auc: 0.9274 - prc: 0.7673 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7334e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9449 - val_prc: 0.8511
Epoch 40/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0033 - cross entropy: 0.0033 - Brier score: 6.0397e-04 - tp: 202.0000 - fp: 28.0000 - tn: 181949.0000 - fn: 97.0000 - accuracy: 0.9993 - precision: 0.8783 - recall: 0.6756 - auc: 0.9358 - prc: 0.7739 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7153e-04 - val_tp: 62.0000 - val_fp: 4.0000 - val_tn: 45483.0000 - val_fn: 20.0000 - val_accuracy: 0.9995 - val_precision: 0.9394 - val_recall: 0.7561 - val_auc: 0.9449 - val_prc: 0.8499
Epoch 41/100
90/90 [==============================] - 0s 5ms/step - loss: 0.0035 - cross entropy: 0.0035 - Brier score: 6.6866e-04 - tp: 186.0000 - fp: 25.0000 - tn: 181952.0000 - fn: 113.0000 - accuracy: 0.9992 - precision: 0.8815 - recall: 0.6221 - auc: 0.9407 - prc: 0.7539 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.6169e-04 - val_tp: 66.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 16.0000 - val_accuracy: 0.9995 - val_precision: 0.9296 - val_recall: 0.8049 - val_auc: 0.9510 - val_prc: 0.8571
Epoch 42/100
Restoring model weights from the end of the best epoch: 32.
90/90 [==============================] - 0s 5ms/step - loss: 0.0033 - cross entropy: 0.0033 - Brier score: 6.3642e-04 - tp: 193.0000 - fp: 32.0000 - tn: 181945.0000 - fn: 106.0000 - accuracy: 0.9992 - precision: 0.8578 - recall: 0.6455 - auc: 0.9441 - prc: 0.7655 - val_loss: 0.0028 - val_cross entropy: 0.0028 - val_Brier score: 4.7751e-04 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45482.0000 - val_fn: 23.0000 - val_accuracy: 0.9994 - val_precision: 0.9219 - val_recall: 0.7195 - val_auc: 0.9510 - val_prc: 0.8563
Epoch 42: early stopping

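Check training history

In this section, you will plot the model's loss and selected metrics on the training and validation sets. These plots are useful for checking for overfitting, which you can learn more about in the Overfit and underfit tutorial.

Additionally, you can produce these plots for any of the metrics you created above. The function below plots the loss, the area under the precision-recall curve (PRC), precision, and recall.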

def plot_metrics(history):
  metrics = ['loss', 'prc', 'precision', 'recall']
  for n, metric in enumerate(metrics):
    name = metric.replace("_"," ").capitalize()
    plt.subplot(2,2,n+1)
    plt.plot(history.epoch, history.history[metric], color=colors[0], label='Train')
    plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
    plt.xlabel('Epoch')
    plt.ylabel(name)
    if metric == 'loss':
      plt.ylim([0, plt.ylim()[1]])
    elif metric == 'auc':
      plt.ylim([0.8,1])
    else:
      plt.ylim([0,1])

    plt.legend()
plot_metrics(baseline_history)


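Evaluate metrics

You can use a confusion matrix to summarize the actual versus predicted labels, where the X axis is the predicted label and the Y axis is the actual label: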

train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step
28/28 [==============================] - 0s 1ms/step
def plot_cm(labels, predictions, threshold=0.5):
  cm = confusion_matrix(labels, predictions > threshold)
  plt.figure(figsize=(5,5))
  sns.heatmap(cm, annot=True, fmt="d")
  plt.title('Confusion matrix @{:.2f}'.format(threshold))
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')

  print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
  print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
  print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
  print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
  print('Total Fraudulent Transactions: ', np.sum(cm[1]))

Evaluate your model on the test dataset and display the results for the metrics you created above:

baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)
loss :  0.0038855739403516054
cross entropy :  0.0038855739403516054
Brier score :  0.0006162827485240996
tp :  81.0
fp :  11.0
tn :  56840.0
fn :  30.0
accuracy :  0.9992802143096924
precision :  0.8804348111152649
recall :  0.7297297120094299
auc :  0.9096326231956482
prc :  0.7863917350769043

Legitimate Transactions Detected (True Negatives):  56840
Legitimate Transactions Incorrectly Detected (False Positives):  11
Fraudulent Transactions Missed (False Negatives):  30
Fraudulent Transactions Detected (True Positives):  81
Total Fraudulent Transactions:  111


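If the model had predicted everything perfectly (impossible with true randomness), this would be a diagonal matrix, where the values off the main diagonal, which indicate incorrect predictions, would be zero. In this case, the matrix shows that you have relatively few false positives, meaning that relatively few legitimate transactions were incorrectly flagged as fraudulent.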
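As a quick check, the four confusion-matrix counts are enough to recompute the headline metrics by hand. The sketch below copies the counts from the baseline test output above (threshold 0.5); the variable names are illustrative only.

```python
# Counts copied from the baseline test-set output above (threshold 0.5).
tn, fp, fn, tp = 56840, 11, 30, 81

# Share of flagged transactions that really are fraudulent.
precision = tp / (tp + fp)
# Share of fraudulent transactions that were flagged.
recall = tp / (tp + fn)
# Share of legitimate transactions that were flagged by mistake.
false_positive_rate = fp / (fp + tn)
# Share of all transactions classified correctly.
accuracy = (tp + tn) / (tp + tn + fp + fn)

print(f"precision: {precision:.4f}")
print(f"recall: {recall:.4f}")
print(f"false positive rate: {false_positive_rate:.6f}")
print(f"accuracy: {accuracy:.4f}")
```

The recomputed precision and recall agree with the values Keras reported, and the accuracy above 99.9% shows why accuracy alone says little on a dataset this imbalanced.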

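Changing the threshold

The default threshold of \(t=50\%\) corresponds to equal costs for false negatives and false positives. In fraud detection, however, you would likely associate a higher cost with false negatives than with false positives. This trade-off may be preferable because a false negative lets a fraudulent transaction go through, whereas a false positive may only cause an email to be sent to a customer asking them to verify their card activity.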
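The link between costs and thresholds can be made explicit. Assuming the model outputs well-calibrated probabilities and correct decisions cost nothing (both are assumptions, not something this tutorial verifies), flagging a transaction with fraud probability p is cheaper than letting it through when p * cost_fn > (1 - p) * cost_fp, which gives the threshold below. The function name and cost values are hypothetical.

```python
def cost_optimal_threshold(cost_fp: float, cost_fn: float) -> float:
    """Threshold above which flagging has the lower expected cost.

    Assumes calibrated probabilities and zero cost for correct decisions.
    """
    if cost_fp <= 0 or cost_fn <= 0:
        raise ValueError("costs must be positive")
    return cost_fp / (cost_fp + cost_fn)


# Equal costs recover the default threshold of 50%.
equal_costs = cost_optimal_threshold(1.0, 1.0)
# A missed fraud that costs 9x or 99x a false alarm (hypothetical ratios)
# corresponds to thresholds of 10% and 1%.
nine_to_one = cost_optimal_threshold(1.0, 9.0)
ninety_nine_to_one = cost_optimal_threshold(1.0, 99.0)

print(equal_costs, nine_to_one, ninety_nine_to_one)
```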

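Lowering the threshold attributes a higher cost to false negatives, so fewer fraudulent transactions are missed, at the price of more false positives. The thresholds tested here are 10% and 1%.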

plot_cm(test_labels, test_predictions_baseline, threshold=0.1)
plot_cm(test_labels, test_predictions_baseline, threshold=0.01)
Legitimate Transactions Detected (True Negatives):  56834
Legitimate Transactions Incorrectly Detected (False Positives):  17
Fraudulent Transactions Missed (False Negatives):  23
Fraudulent Transactions Detected (True Positives):  88
Total Fraudulent Transactions:  111
Legitimate Transactions Detected (True Negatives):  56806
Legitimate Transactions Incorrectly Detected (False Positives):  45
Fraudulent Transactions Missed (False Negatives):  22
Fraudulent Transactions Detected (True Positives):  89
Total Fraudulent Transactions:  111


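Plot the ROC

Now plot the ROC. This plot is useful because it shows, at a glance, the range of performance the model can reach by tuning the output threshold over its full range (0 to 1). Each point on the curve corresponds to a single value of the threshold.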
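To see how one threshold becomes one point on the curve, here is a minimal pure-Python sketch. The labels and scores are made-up toy data, `roc_point` is a hypothetical helper, and the strict `>` comparison mirrors the `predictions > threshold` convention used in `plot_cm`.

```python
def roc_point(labels, scores, threshold):
    """Return (false positive rate, true positive rate) at one threshold."""
    tp = sum(1 for y, s in zip(labels, scores) if y == 1 and s > threshold)
    fp = sum(1 for y, s in zip(labels, scores) if y == 0 and s > threshold)
    pos = sum(labels)
    neg = len(labels) - pos
    return fp / neg, tp / pos


# Toy data for illustration only: 5 negatives and 3 positives.
labels = [0, 0, 0, 0, 1, 0, 1, 1]
scores = [0.05, 0.10, 0.20, 0.40, 0.35, 0.60, 0.80, 0.90]

# Sweeping the threshold downward traces the curve from (0, 0) to (1, 1).
for threshold in (0.9, 0.5, 0.3, 0.0):
    fpr, tpr = roc_point(labels, scores, threshold)
    print(f"threshold={threshold:.1f}  FPR={fpr:.2f}  TPR={tpr:.2f}")
```

`sklearn.metrics.roc_curve`, used in the function below, performs the same sweep over every distinct score value.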

def plot_roc(name, labels, predictions, **kwargs):
  fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

  plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.xlim([-0.5,20])
  plt.ylim([80,100.5])
  plt.grid(True)
  ax = plt.gca()
  ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right');


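Plot the PRC

Now plot the precision-recall curve (PRC). The area under it (AUPRC) is the area under the interpolated precision-recall curve, obtained by plotting (recall, precision) points for different values of the classification threshold. Depending on how it is calculated, the PR AUC may be equivalent to the average precision of the model.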
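For reference, average precision is commonly defined as a step-wise sum: each time recall increases, the precision at that point is added, and the total is divided by the number of positives. The sketch below implements that definition in pure Python on made-up toy data; `average_precision` is a hypothetical helper that assumes no tied scores, and it is not the interpolated computation Keras uses for the `prc` metric.

```python
def average_precision(labels, scores):
    """Step-wise average precision; assumes all scores are distinct."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    total_pos = sum(labels)
    tp = 0
    ap = 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            tp += 1
            # Precision among the top `rank` predictions, counted once per
            # positive because recall rises by 1/total_pos at this step.
            ap += tp / rank
    return ap / total_pos


# Toy data for illustration only: positives sit at ranks 1, 2 and 5.
labels = [0, 0, 0, 0, 1, 0, 1, 1]
scores = [0.05, 0.10, 0.20, 0.40, 0.35, 0.60, 0.80, 0.90]

ap = average_precision(labels, scores)
print(f"average precision: {ap:.4f}")
```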

def plot_prc(name, labels, predictions, **kwargs):
    precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, predictions)

    plt.plot(precision, recall, label=name, linewidth=2, **kwargs)
    plt.xlabel('Precision')
    plt.ylabel('Recall')
    plt.grid(True)
    ax = plt.gca()
    ax.set_aspect('equal')
plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right');


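It looks like the precision is relatively high, but the recall and the area under the ROC curve (AUC) are not as high as you might like. Classifiers often struggle to maximize precision and recall at the same time, especially on imbalanced datasets. It is important to consider the costs of different types of errors in the context of the problem you care about. In this example, a false negative (a fraudulent transaction is missed) may have a financial cost, while a false positive (a transaction is incorrectly flagged as fraudulent) may decrease user satisfaction.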

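Class weights

Calculate class weights

The goal is to identify fraudulent transactions, but you do not have many of those positive samples to work with, so you want the classifier to weight the few available examples heavily. You can do this by passing Keras a weight for each class through the class_weight argument. These weights cause the model to "pay more attention" to examples from the under-represented class. Note, however, that this does not increase the amount of information in your dataset in any way. In the end, using class weights is more or less equivalent to changing the output bias or changing the threshold. Let's see how it works out.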

# Scaling by total/2 helps keep the loss to a similar magnitude.
# The sum of the weights of all examples stays the same.
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))
Weight for class 0: 0.50
Weight for class 1: 289.44
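The comment in the cell above says the sum of the weights over all examples stays the same. A small sketch can confirm this, using the class counts quoted in the introduction (492 fraudulent transactions out of 284,807). The `weighted_bce` helper is hypothetical and only illustrates how a class weight multiplies an example's cross-entropy loss.

```python
import math

# Class counts from the dataset description: 492 frauds out of 284,807.
pos = 492
total = 284807
neg = total - pos

weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)

# Each class contributes total / 2, so the weighted example count is unchanged.
weighted_total = neg * weight_for_0 + pos * weight_for_1
print(f"weighted total: {weighted_total:.1f} (unweighted: {total})")


def weighted_bce(label, prob, class_weight):
    """Binary cross-entropy of one example, scaled by its class weight."""
    loss = -math.log(prob) if label == 1 else -math.log(1.0 - prob)
    return class_weight[label] * loss


class_weight = {0: weight_for_0, 1: weight_for_1}
# A missed fraud (label 1, predicted probability 0.1) is penalized far more
# than an equally wrong prediction on a legitimate transaction.
print(weighted_bce(1, 0.1, class_weight), weighted_bce(0, 0.9, class_weight))
```

The ratio of the two weights equals neg / pos, roughly 578, so one misclassified fraud weighs about as much as 578 misclassified legitimate transactions.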

Train a model with class weights

Now try re-training and evaluating the model with class weights to see how that affects the predictions.

weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight)
Epoch 1/100
90/90 [==============================] - 2s 11ms/step - loss: 0.9262 - cross entropy: 0.0166 - Brier score: 0.0027 - tp: 233.0000 - fp: 530.0000 - tn: 238298.0000 - fn: 177.0000 - accuracy: 0.9970 - precision: 0.3054 - recall: 0.5683 - auc: 0.8803 - prc: 0.4222 - val_loss: 0.0116 - val_cross entropy: 0.0116 - val_Brier score: 0.0011 - val_tp: 67.0000 - val_fp: 35.0000 - val_tn: 45452.0000 - val_fn: 15.0000 - val_accuracy: 0.9989 - val_precision: 0.6569 - val_recall: 0.8171 - val_auc: 0.9519 - val_prc: 0.7255
Epoch 2/100
90/90 [==============================] - 0s 5ms/step - loss: 0.6152 - cross entropy: 0.0319 - Brier score: 0.0059 - tp: 202.0000 - fp: 1117.0000 - tn: 180860.0000 - fn: 97.0000 - accuracy: 0.9933 - precision: 0.1531 - recall: 0.6756 - auc: 0.9051 - prc: 0.4410 - val_loss: 0.0172 - val_cross entropy: 0.0172 - val_Brier score: 0.0019 - val_tp: 69.0000 - val_fp: 72.0000 - val_tn: 45415.0000 - val_fn: 13.0000 - val_accuracy: 0.9981 - val_precision: 0.4894 - val_recall: 0.8415 - val_auc: 0.9577 - val_prc: 0.7220
Epoch 3/100
90/90 [==============================] - 0s 5ms/step - loss: 0.4397 - cross entropy: 0.0461 - Brier score: 0.0095 - tp: 229.0000 - fp: 1929.0000 - tn: 180048.0000 - fn: 70.0000 - accuracy: 0.9890 - precision: 0.1061 - recall: 0.7659 - auc: 0.9307 - prc: 0.4134 - val_loss: 0.0236 - val_cross entropy: 0.0236 - val_Brier score: 0.0029 - val_tp: 69.0000 - val_fp: 106.0000 - val_tn: 45381.0000 - val_fn: 13.0000 - val_accuracy: 0.9974 - val_precision: 0.3943 - val_recall: 0.8415 - val_auc: 0.9662 - val_prc: 0.7291
Epoch 4/100
90/90 [==============================] - 0s 5ms/step - loss: 0.4155 - cross entropy: 0.0619 - Brier score: 0.0136 - tp: 231.0000 - fp: 2898.0000 - tn: 179079.0000 - fn: 68.0000 - accuracy: 0.9837 - precision: 0.0738 - recall: 0.7726 - auc: 0.9272 - prc: 0.3804 - val_loss: 0.0319 - val_cross entropy: 0.0319 - val_Brier score: 0.0046 - val_tp: 70.0000 - val_fp: 188.0000 - val_tn: 45299.0000 - val_fn: 12.0000 - val_accuracy: 0.9956 - val_precision: 0.2713 - val_recall: 0.8537 - val_auc: 0.9697 - val_prc: 0.7095
Epoch 5/100
90/90 [==============================] - 0s 5ms/step - loss: 0.3247 - cross entropy: 0.0773 - Brier score: 0.0178 - tp: 241.0000 - fp: 3872.0000 - tn: 178105.0000 - fn: 58.0000 - accuracy: 0.9784 - precision: 0.0586 - recall: 0.8060 - auc: 0.9471 - prc: 0.3673 - val_loss: 0.0405 - val_cross entropy: 0.0405 - val_Brier score: 0.0068 - val_tp: 71.0000 - val_fp: 334.0000 - val_tn: 45153.0000 - val_fn: 11.0000 - val_accuracy: 0.9924 - val_precision: 0.1753 - val_recall: 0.8659 - val_auc: 0.9714 - val_prc: 0.6518
Epoch 6/100
90/90 [==============================] - 0s 6ms/step - loss: 0.3481 - cross entropy: 0.0976 - Brier score: 0.0225 - tp: 248.0000 - fp: 4880.0000 - tn: 177097.0000 - fn: 51.0000 - accuracy: 0.9729 - precision: 0.0484 - recall: 0.8294 - auc: 0.9351 - prc: 0.3069 - val_loss: 0.0494 - val_cross entropy: 0.0494 - val_Brier score: 0.0093 - val_tp: 73.0000 - val_fp: 511.0000 - val_tn: 44976.0000 - val_fn: 9.0000 - val_accuracy: 0.9886 - val_precision: 0.1250 - val_recall: 0.8902 - val_auc: 0.9742 - val_prc: 0.6313
Epoch 7/100
90/90 [==============================] - 0s 5ms/step - loss: 0.2719 - cross entropy: 0.1078 - Brier score: 0.0253 - tp: 257.0000 - fp: 5673.0000 - tn: 176304.0000 - fn: 42.0000 - accuracy: 0.9686 - precision: 0.0433 - recall: 0.8595 - auc: 0.9564 - prc: 0.2894 - val_loss: 0.0565 - val_cross entropy: 0.0565 - val_Brier score: 0.0112 - val_tp: 73.0000 - val_fp: 633.0000 - val_tn: 44854.0000 - val_fn: 9.0000 - val_accuracy: 0.9859 - val_precision: 0.1034 - val_recall: 0.8902 - val_auc: 0.9757 - val_prc: 0.6267
Epoch 8/100
90/90 [==============================] - 0s 6ms/step - loss: 0.2623 - cross entropy: 0.1179 - Brier score: 0.0275 - tp: 262.0000 - fp: 6123.0000 - tn: 175854.0000 - fn: 37.0000 - accuracy: 0.9662 - precision: 0.0410 - recall: 0.8763 - auc: 0.9554 - prc: 0.2609 - val_loss: 0.0607 - val_cross entropy: 0.0607 - val_Brier score: 0.0124 - val_tp: 73.0000 - val_fp: 686.0000 - val_tn: 44801.0000 - val_fn: 9.0000 - val_accuracy: 0.9847 - val_precision: 0.0962 - val_recall: 0.8902 - val_auc: 0.9754 - val_prc: 0.6069
Epoch 9/100
90/90 [==============================] - 0s 5ms/step - loss: 0.2915 - cross entropy: 0.1184 - Brier score: 0.0280 - tp: 257.0000 - fp: 6295.0000 - tn: 175682.0000 - fn: 42.0000 - accuracy: 0.9652 - precision: 0.0392 - recall: 0.8595 - auc: 0.9494 - prc: 0.2652 - val_loss: 0.0653 - val_cross entropy: 0.0653 - val_Brier score: 0.0135 - val_tp: 74.0000 - val_fp: 742.0000 - val_tn: 44745.0000 - val_fn: 8.0000 - val_accuracy: 0.9835 - val_precision: 0.0907 - val_recall: 0.9024 - val_auc: 0.9773 - val_prc: 0.5856
Epoch 10/100
90/90 [==============================] - 0s 6ms/step - loss: 0.2632 - cross entropy: 0.1336 - Brier score: 0.0313 - tp: 259.0000 - fp: 6976.0000 - tn: 175001.0000 - fn: 40.0000 - accuracy: 0.9615 - precision: 0.0358 - recall: 0.8662 - auc: 0.9561 - prc: 0.2365 - val_loss: 0.0700 - val_cross entropy: 0.0700 - val_Brier score: 0.0146 - val_tp: 76.0000 - val_fp: 801.0000 - val_tn: 44686.0000 - val_fn: 6.0000 - val_accuracy: 0.9823 - val_precision: 0.0867 - val_recall: 0.9268 - val_auc: 0.9773 - val_prc: 0.5876
Epoch 11/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2336 - cross entropy: 0.1282 - Brier score: 0.0299 - tp: 269.0000 - fp: 6690.0000 - tn: 175287.0000 - fn: 30.0000 - accuracy: 0.9631 - precision: 0.0387 - recall: 0.8997 - auc: 0.9586 - prc: 0.2494 - val_loss: 0.0679 - val_cross entropy: 0.0679 - val_Brier score: 0.0140 - val_tp: 76.0000 - val_fp: 757.0000 - val_tn: 44730.0000 - val_fn: 6.0000 - val_accuracy: 0.9833 - val_precision: 0.0912 - val_recall: 0.9268 - val_auc: 0.9777 - val_prc: 0.5891
Epoch 12/100
90/90 [==============================] - 1s 6ms/step - loss: 0.2399 - cross entropy: 0.1289 - Brier score: 0.0298 - tp: 265.0000 - fp: 6654.0000 - tn: 175323.0000 - fn: 34.0000 - accuracy: 0.9633 - precision: 0.0383 - recall: 0.8863 - auc: 0.9602 - prc: 0.2601 - val_loss: 0.0684 - val_cross entropy: 0.0684 - val_Brier score: 0.0141 - val_tp: 76.0000 - val_fp: 762.0000 - val_tn: 44725.0000 - val_fn: 6.0000 - val_accuracy: 0.9831 - val_precision: 0.0907 - val_recall: 0.9268 - val_auc: 0.9784 - val_prc: 0.5848
Epoch 13/100
Restoring model weights from the end of the best epoch: 3.
90/90 [==============================] - 1s 6ms/step - loss: 0.2341 - cross entropy: 0.1275 - Brier score: 0.0297 - tp: 262.0000 - fp: 6631.0000 - tn: 175346.0000 - fn: 37.0000 - accuracy: 0.9634 - precision: 0.0380 - recall: 0.8763 - auc: 0.9665 - prc: 0.2538 - val_loss: 0.0757 - val_cross entropy: 0.0757 - val_Brier score: 0.0159 - val_tp: 76.0000 - val_fp: 834.0000 - val_tn: 44653.0000 - val_fn: 6.0000 - val_accuracy: 0.9816 - val_precision: 0.0835 - val_recall: 0.9268 - val_auc: 0.9789 - val_prc: 0.5709
Epoch 13: early stopping
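The stopping point in this log can be reproduced with a simplified patience rule. The sketch below assumes the `early_stopping` callback monitors `val_prc`, maximizes it, and uses a patience of 10; those settings are defined outside this section, so treat them as assumptions that happen to be consistent with the log. `early_stopping_epoch` is a hypothetical helper, not the Keras implementation.

```python
def early_stopping_epoch(values, patience):
    """Return (stop_epoch, best_epoch), 1-indexed, for a metric to maximize."""
    best_value = float("-inf")
    best_epoch = 0
    wait = 0
    for epoch, value in enumerate(values, start=1):
        if value > best_value:
            best_value, best_epoch, wait = value, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Patience never ran out: training used every epoch.
    return len(values), best_epoch


# val_prc values copied from the class-weighted training log above.
val_prc = [0.7255, 0.7220, 0.7291, 0.7095, 0.6518, 0.6313, 0.6267,
           0.6069, 0.5856, 0.5876, 0.5891, 0.5848, 0.5709]

stop_epoch, best_epoch = early_stopping_epoch(val_prc, patience=10)
print(f"stopped at epoch {stop_epoch}, best epoch {best_epoch}")
```

Under these assumptions the rule stops at epoch 13 and restores epoch 3, matching the log.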

Check training history

plot_metrics(weighted_history)


Evaluate metrics

train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step
28/28 [==============================] - 0s 1ms/step
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
  print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)
loss :  0.024716919288039207
cross entropy :  0.024716919288039207
Brier score :  0.0029473488684743643
tp :  88.0
fp :  134.0
tn :  56717.0
fn :  23.0
accuracy :  0.9972437620162964
precision :  0.3963963985443115
recall :  0.792792797088623
auc :  0.9477326273918152
prc :  0.6732124090194702

Legitimate Transactions Detected (True Negatives):  56717
Legitimate Transactions Incorrectly Detected (False Positives):  134
Fraudulent Transactions Missed (False Negatives):  23
Fraudulent Transactions Detected (True Positives):  88
Total Fraudulent Transactions:  111


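Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also finds more true positives. Despite its lower accuracy, this model has higher recall (and identifies more fraudulent transactions than the baseline model at a threshold of 50%). Of course, both types of error have a cost (you would not want to bother users by flagging too many legitimate transactions as fraudulent either). Carefully consider the trade-offs between these different types of errors for your application.

Compared with the baseline model evaluated at a lower threshold, the class-weighted model is clearly worse: at a threshold of 10% the baseline catches the same 88 fraudulent transactions with 17 false positives instead of 134. The baseline's advantage is further supported by its lower test loss (cross entropy and the Brier score, which is the mean squared error), and it can also be seen by plotting the ROC curves of both models together.

Plot the ROC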

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right');


Plot the PRC

plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right');


Oversampling

Oversample the minority class

A related approach is to resample the dataset by oversampling the minority class.

pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]

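Using NumPy

You can balance the dataset manually by drawing the right number of random indices, with replacement, from the positive examples: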

ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
(181977, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
(363954, 29)

Using tf.data

If you are using tf.data, the easiest way to produce balanced examples is to start with a positive and a negative dataset and merge them. See the tf.data guide for more examples.

BUFFER_SIZE = 100000

def make_ds(features, labels):
  ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
  ds = ds.shuffle(BUFFER_SIZE).repeat()
  return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

Each dataset provides (feature, label) pairs:

for features, label in pos_ds.take(1):
  print("Features:\n", features.numpy())
  print()
  print("Label: ", label.numpy())
Features:
 [ 4.57437149e-03  1.41282803e+00 -1.70738347e+00  7.86145002e-01
  2.34322123e+00 -1.32760854e+00  1.68238195e+00 -7.10272314e-01
  8.18760297e-01 -3.09684905e+00  2.01295966e+00 -3.98984767e+00
  1.02827419e+00 -5.00000000e+00 -1.25820263e+00  1.91494135e+00
  5.00000000e+00  3.32009026e+00 -2.75342824e+00 -8.47588695e-03
 -7.83382558e-01 -1.24259811e+00 -6.45039879e-01 -1.71393384e-02
  1.13211907e+00 -1.52256293e+00 -1.08919872e+00 -1.06657977e+00
 -1.45889491e+00]

Label:  1

Merge the two together using tf.data.Dataset.sample_from_datasets:

resampled_ds = tf.data.Dataset.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
  print(label.numpy().mean())
0.50341796875

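To use this dataset, you will need the number of steps per epoch.

The definition of "epoch" in this case is less clear. Say it is the number of batches required to see each negative example once: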

resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0

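Train on the oversampled data

Now try training the model with the resampled dataset instead of using class weights to see how the two methods compare.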

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks=[early_stopping],
    validation_data=val_ds)
Epoch 1/100
278/278 [==============================] - 8s 22ms/step - loss: 0.3612 - cross entropy: 0.3306 - Brier score: 0.1096 - tp: 258749.0000 - fp: 76958.0000 - tn: 264513.0000 - fn: 26086.0000 - accuracy: 0.8355 - precision: 0.7708 - recall: 0.9084 - auc: 0.9490 - prc: 0.9536 - val_loss: 0.2021 - val_cross entropy: 0.2021 - val_Brier score: 0.0446 - val_tp: 75.0000 - val_fp: 1144.0000 - val_tn: 44343.0000 - val_fn: 7.0000 - val_accuracy: 0.9747 - val_precision: 0.0615 - val_recall: 0.9146 - val_auc: 0.9741 - val_prc: 0.7919
Epoch 2/100
278/278 [==============================] - 5s 20ms/step - loss: 0.1757 - cross entropy: 0.1757 - Brier score: 0.0515 - tp: 262968.0000 - fp: 15885.0000 - tn: 269124.0000 - fn: 21367.0000 - accuracy: 0.9346 - precision: 0.9430 - recall: 0.9249 - auc: 0.9817 - prc: 0.9852 - val_loss: 0.1003 - val_cross entropy: 0.1003 - val_Brier score: 0.0205 - val_tp: 76.0000 - val_fp: 858.0000 - val_tn: 44629.0000 - val_fn: 6.0000 - val_accuracy: 0.9810 - val_precision: 0.0814 - val_recall: 0.9268 - val_auc: 0.9777 - val_prc: 0.7702
Epoch 3/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1358 - cross entropy: 0.1358 - Brier score: 0.0398 - tp: 266700.0000 - fp: 11006.0000 - tn: 273451.0000 - fn: 18187.0000 - accuracy: 0.9487 - precision: 0.9604 - recall: 0.9362 - auc: 0.9891 - prc: 0.9904 - val_loss: 0.0725 - val_cross entropy: 0.0725 - val_Brier score: 0.0158 - val_tp: 76.0000 - val_fp: 790.0000 - val_tn: 44697.0000 - val_fn: 6.0000 - val_accuracy: 0.9825 - val_precision: 0.0878 - val_recall: 0.9268 - val_auc: 0.9766 - val_prc: 0.7553
Epoch 4/100
278/278 [==============================] - 6s 21ms/step - loss: 0.1151 - cross entropy: 0.1151 - Brier score: 0.0341 - tp: 269190.0000 - fp: 9719.0000 - tn: 274441.0000 - fn: 15994.0000 - accuracy: 0.9548 - precision: 0.9652 - recall: 0.9439 - auc: 0.9925 - prc: 0.9930 - val_loss: 0.0596 - val_cross entropy: 0.0596 - val_Brier score: 0.0136 - val_tp: 76.0000 - val_fp: 725.0000 - val_tn: 44762.0000 - val_fn: 6.0000 - val_accuracy: 0.9840 - val_precision: 0.0949 - val_recall: 0.9268 - val_auc: 0.9726 - val_prc: 0.7292
Epoch 5/100
278/278 [==============================] - 6s 20ms/step - loss: 0.1006 - cross entropy: 0.1006 - Brier score: 0.0299 - tp: 270949.0000 - fp: 8853.0000 - tn: 275916.0000 - fn: 13626.0000 - accuracy: 0.9605 - precision: 0.9684 - recall: 0.9521 - auc: 0.9945 - prc: 0.9946 - val_loss: 0.0525 - val_cross entropy: 0.0525 - val_Brier score: 0.0124 - val_tp: 76.0000 - val_fp: 668.0000 - val_tn: 44819.0000 - val_fn: 6.0000 - val_accuracy: 0.9852 - val_precision: 0.1022 - val_recall: 0.9268 - val_auc: 0.9717 - val_prc: 0.7216
Epoch 6/100
278/278 [==============================] - 6s 20ms/step - loss: 0.0904 - cross entropy: 0.0904 - Brier score: 0.0268 - tp: 272681.0000 - fp: 8122.0000 - tn: 276344.0000 - fn: 12197.0000 - accuracy: 0.9643 - precision: 0.9711 - recall: 0.9572 - auc: 0.9958 - prc: 0.9956 - val_loss: 0.0456 - val_cross entropy: 0.0456 - val_Brier score: 0.0108 - val_tp: 76.0000 - val_fp: 576.0000 - val_tn: 44911.0000 - val_fn: 6.0000 - val_accuracy: 0.9872 - val_precision: 0.1166 - val_recall: 0.9268 - val_auc: 0.9737 - val_prc: 0.7304
Epoch 7/100
278/278 [==============================] - 6s 20ms/step - loss: 0.0828 - cross entropy: 0.0828 - Brier score: 0.0244 - tp: 273911.0000 - fp: 7426.0000 - tn: 277008.0000 - fn: 10999.0000 - accuracy: 0.9676 - precision: 0.9736 - recall: 0.9614 - auc: 0.9965 - prc: 0.9963 - val_loss: 0.0408 - val_cross entropy: 0.0408 - val_Brier score: 0.0099 - val_tp: 77.0000 - val_fp: 546.0000 - val_tn: 44941.0000 - val_fn: 5.0000 - val_accuracy: 0.9879 - val_precision: 0.1236 - val_recall: 0.9390 - val_auc: 0.9752 - val_prc: 0.7232
Epoch 8/100
278/278 [==============================] - 5s 20ms/step - loss: 0.0775 - cross entropy: 0.0775 - Brier score: 0.0228 - tp: 274985.0000 - fp: 6904.0000 - tn: 277146.0000 - fn: 10309.0000 - accuracy: 0.9698 - precision: 0.9755 - recall: 0.9639 - auc: 0.9970 - prc: 0.9968 - val_loss: 0.0387 - val_cross entropy: 0.0387 - val_Brier score: 0.0096 - val_tp: 77.0000 - val_fp: 568.0000 - val_tn: 44919.0000 - val_fn: 5.0000 - val_accuracy: 0.9874 - val_precision: 0.1194 - val_recall: 0.9390 - val_auc: 0.9761 - val_prc: 0.7145
Epoch 9/100
278/278 [==============================] - 5s 20ms/step - loss: 0.0743 - cross entropy: 0.0743 - Brier score: 0.0219 - tp: 274086.0000 - fp: 6704.0000 - tn: 278828.0000 - fn: 9726.0000 - accuracy: 0.9711 - precision: 0.9761 - recall: 0.9657 - auc: 0.9971 - prc: 0.9969 - val_loss: 0.0344 - val_cross entropy: 0.0344 - val_Brier score: 0.0085 - val_tp: 76.0000 - val_fp: 492.0000 - val_tn: 44995.0000 - val_fn: 6.0000 - val_accuracy: 0.9891 - val_precision: 0.1338 - val_recall: 0.9268 - val_auc: 0.9767 - val_prc: 0.7147
Epoch 10/100
278/278 [==============================] - 5s 20ms/step - loss: 0.0712 - cross entropy: 0.0712 - Brier score: 0.0211 - tp: 275221.0000 - fp: 6399.0000 - tn: 278199.0000 - fn: 9525.0000 - accuracy: 0.9720 - precision: 0.9773 - recall: 0.9665 - auc: 0.9973 - prc: 0.9970 - val_loss: 0.0311 - val_cross entropy: 0.0311 - val_Brier score: 0.0077 - val_tp: 76.0000 - val_fp: 434.0000 - val_tn: 45053.0000 - val_fn: 6.0000 - val_accuracy: 0.9903 - val_precision: 0.1490 - val_recall: 0.9268 - val_auc: 0.9772 - val_prc: 0.7140
Epoch 11/100
Restoring model weights from the end of the best epoch: 1.
278/278 [==============================] - 5s 20ms/step - loss: 0.0695 - cross entropy: 0.0695 - Brier score: 0.0206 - tp: 275842.0000 - fp: 6384.0000 - tn: 277849.0000 - fn: 9269.0000 - accuracy: 0.9725 - precision: 0.9774 - recall: 0.9675 - auc: 0.9973 - prc: 0.9970 - val_loss: 0.0302 - val_cross entropy: 0.0302 - val_Brier score: 0.0075 - val_tp: 76.0000 - val_fp: 433.0000 - val_tn: 45054.0000 - val_fn: 6.0000 - val_accuracy: 0.9904 - val_precision: 0.1493 - val_recall: 0.9268 - val_auc: 0.9775 - val_prc: 0.7154
Epoch 11: early stopping

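If the training process considered the whole dataset on each gradient update, this oversampling would be basically identical to class weighting.

But when you train the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: instead of each positive example appearing in one batch with a large weight, it appears in many different batches, each time with a small weight.

This smoother gradient signal makes the model easier to train.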
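One rough way to quantify "smoother" is to look at how much the number of positives per batch fluctuates. The sketch below treats each batch as independent random draws (a simplification) and assumes a batch size of 2048, which is not stated in this section; `positive_count_cv` is a hypothetical helper that returns the coefficient of variation of a binomial count.

```python
import math


def positive_count_cv(batch_size, positive_fraction):
    """Std / mean of the number of positives in a randomly drawn batch."""
    mean = batch_size * positive_fraction
    std = math.sqrt(batch_size * positive_fraction * (1.0 - positive_fraction))
    return std / mean


BATCH_SIZE = 2048  # assumed value; defined earlier in the tutorial

# Class weighting: batches keep the original class ratio (492 of 284,807),
# so a batch holds only a handful of heavily weighted positives.
cv_weighted = positive_count_cv(BATCH_SIZE, 492 / 284807)

# Oversampling: about half of every batch is positive, each with weight 1.
cv_oversampled = positive_count_cv(BATCH_SIZE, 0.5)

print(f"class weighting: {cv_weighted:.3f}")
print(f"oversampling:    {cv_oversampled:.3f}")
```

Under these assumptions the relative fluctuation of the positive class's contribution is more than an order of magnitude smaller with oversampling, which is the intuition behind the smoother gradient signal.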

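Check training history

Note that the distributions of the metrics will be different here, because the training data has a completely different distribution from the validation and test data.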

plot_metrics(resampled_history)


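Re-train

Because training is easier on the balanced data, the above training procedure may overfit quickly.

So break up the epochs to give tf.keras.callbacks.EarlyStopping finer control over when to stop training.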

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# Reset the bias to zero, since this dataset is balanced.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # These are not real epochs
    steps_per_epoch=20,
    epochs=10*EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_ds))
Epoch 1/1000
20/20 [==============================] - 2s 47ms/step - loss: 0.6826 - cross entropy: 0.3390 - Brier score: 0.1176 - tp: 18430.0000 - fp: 14493.0000 - tn: 51299.0000 - fn: 2307.0000 - accuracy: 0.8058 - precision: 0.5598 - recall: 0.8887 - auc: 0.9464 - prc: 0.8794 - val_loss: 1.0004 - val_cross entropy: 1.0004 - val_Brier score: 0.3825 - val_tp: 79.0000 - val_fp: 36434.0000 - val_tn: 9053.0000 - val_fn: 3.0000 - val_accuracy: 0.2004 - val_precision: 0.0022 - val_recall: 0.9634 - val_auc: 0.9261 - val_prc: 0.5749
Epoch 2/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.5885 - cross entropy: 0.5885 - Brier score: 0.2089 - tp: 18504.0000 - fp: 12626.0000 - tn: 7862.0000 - fn: 1968.0000 - accuracy: 0.6437 - precision: 0.5944 - recall: 0.9039 - auc: 0.8752 - prc: 0.9121 - val_loss: 0.8348 - val_cross entropy: 0.8348 - val_Brier score: 0.3117 - val_tp: 79.0000 - val_fp: 28884.0000 - val_tn: 16603.0000 - val_fn: 3.0000 - val_accuracy: 0.3661 - val_precision: 0.0027 - val_recall: 0.9634 - val_auc: 0.9395 - val_prc: 0.6709
Epoch 3/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.5070 - cross entropy: 0.5070 - Brier score: 0.1790 - tp: 18418.0000 - fp: 10425.0000 - tn: 10193.0000 - fn: 1924.0000 - accuracy: 0.6985 - precision: 0.6386 - recall: 0.9054 - auc: 0.8991 - prc: 0.9280 - val_loss: 0.6975 - val_cross entropy: 0.6975 - val_Brier score: 0.2495 - val_tp: 78.0000 - val_fp: 19535.0000 - val_tn: 25952.0000 - val_fn: 4.0000 - val_accuracy: 0.5712 - val_precision: 0.0040 - val_recall: 0.9512 - val_auc: 0.9499 - val_prc: 0.7048
Epoch 4/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.4413 - cross entropy: 0.4413 - Brier score: 0.1530 - tp: 18483.0000 - fp: 8228.0000 - tn: 12349.0000 - fn: 1900.0000 - accuracy: 0.7527 - precision: 0.6920 - recall: 0.9068 - auc: 0.9179 - prc: 0.9406 - val_loss: 0.5893 - val_cross entropy: 0.5893 - val_Brier score: 0.1998 - val_tp: 77.0000 - val_fp: 11782.0000 - val_tn: 33705.0000 - val_fn: 5.0000 - val_accuracy: 0.7413 - val_precision: 0.0065 - val_recall: 0.9390 - val_auc: 0.9552 - val_prc: 0.7246
Epoch 5/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.3914 - cross entropy: 0.3914 - Brier score: 0.1335 - tp: 18615.0000 - fp: 6548.0000 - tn: 13896.0000 - fn: 1901.0000 - accuracy: 0.7937 - precision: 0.7398 - recall: 0.9073 - auc: 0.9304 - prc: 0.9500 - val_loss: 0.5045 - val_cross entropy: 0.5045 - val_Brier score: 0.1613 - val_tp: 77.0000 - val_fp: 7135.0000 - val_tn: 38352.0000 - val_fn: 5.0000 - val_accuracy: 0.8433 - val_precision: 0.0107 - val_recall: 0.9390 - val_auc: 0.9595 - val_prc: 0.7424
Epoch 6/1000
20/20 [==============================] - 0s 26ms/step - loss: 0.3563 - cross entropy: 0.3563 - Brier score: 0.1183 - tp: 18429.0000 - fp: 5050.0000 - tn: 15533.0000 - fn: 1948.0000 - accuracy: 0.8292 - precision: 0.7849 - recall: 0.9044 - auc: 0.9391 - prc: 0.9552 - val_loss: 0.4395 - val_cross entropy: 0.4395 - val_Brier score: 0.1328 - val_tp: 77.0000 - val_fp: 4727.0000 - val_tn: 40760.0000 - val_fn: 5.0000 - val_accuracy: 0.8962 - val_precision: 0.0160 - val_recall: 0.9390 - val_auc: 0.9616 - val_prc: 0.7625
Epoch 7/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.3220 - cross entropy: 0.3220 - Brier score: 0.1047 - tp: 18807.0000 - fp: 4065.0000 - tn: 16241.0000 - fn: 1847.0000 - accuracy: 0.8557 - precision: 0.8223 - recall: 0.9106 - auc: 0.9485 - prc: 0.9631 - val_loss: 0.3867 - val_cross entropy: 0.3867 - val_Brier score: 0.1105 - val_tp: 77.0000 - val_fp: 3192.0000 - val_tn: 42295.0000 - val_fn: 5.0000 - val_accuracy: 0.9298 - val_precision: 0.0236 - val_recall: 0.9390 - val_auc: 0.9635 - val_prc: 0.7711
Epoch 8/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.3012 - cross entropy: 0.3012 - Brier score: 0.0959 - tp: 18607.0000 - fp: 3384.0000 - tn: 17165.0000 - fn: 1804.0000 - accuracy: 0.8733 - precision: 0.8461 - recall: 0.9116 - auc: 0.9545 - prc: 0.9661 - val_loss: 0.3438 - val_cross entropy: 0.3438 - val_Brier score: 0.0932 - val_tp: 77.0000 - val_fp: 2361.0000 - val_tn: 43126.0000 - val_fn: 5.0000 - val_accuracy: 0.9481 - val_precision: 0.0316 - val_recall: 0.9390 - val_auc: 0.9644 - val_prc: 0.7748
Epoch 9/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2795 - cross entropy: 0.2795 - Brier score: 0.0880 - tp: 18636.0000 - fp: 2891.0000 - tn: 17616.0000 - fn: 1817.0000 - accuracy: 0.8851 - precision: 0.8657 - recall: 0.9112 - auc: 0.9589 - prc: 0.9692 - val_loss: 0.3087 - val_cross entropy: 0.3087 - val_Brier score: 0.0799 - val_tp: 76.0000 - val_fp: 1892.0000 - val_tn: 43595.0000 - val_fn: 6.0000 - val_accuracy: 0.9583 - val_precision: 0.0386 - val_recall: 0.9268 - val_auc: 0.9658 - val_prc: 0.7797
Epoch 10/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2620 - cross entropy: 0.2620 - Brier score: 0.0812 - tp: 18743.0000 - fp: 2432.0000 - tn: 17955.0000 - fn: 1830.0000 - accuracy: 0.8959 - precision: 0.8851 - recall: 0.9110 - auc: 0.9625 - prc: 0.9724 - val_loss: 0.2798 - val_cross entropy: 0.2798 - val_Brier score: 0.0695 - val_tp: 76.0000 - val_fp: 1615.0000 - val_tn: 43872.0000 - val_fn: 6.0000 - val_accuracy: 0.9644 - val_precision: 0.0449 - val_recall: 0.9268 - val_auc: 0.9674 - val_prc: 0.7834
Epoch 11/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2480 - cross entropy: 0.2480 - Brier score: 0.0757 - tp: 18645.0000 - fp: 2154.0000 - tn: 18383.0000 - fn: 1778.0000 - accuracy: 0.9040 - precision: 0.8964 - recall: 0.9129 - auc: 0.9668 - prc: 0.9748 - val_loss: 0.2551 - val_cross entropy: 0.2551 - val_Brier score: 0.0611 - val_tp: 76.0000 - val_fp: 1428.0000 - val_tn: 44059.0000 - val_fn: 6.0000 - val_accuracy: 0.9685 - val_precision: 0.0505 - val_recall: 0.9268 - val_auc: 0.9691 - val_prc: 0.7858
Epoch 12/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2368 - cross entropy: 0.2368 - Brier score: 0.0722 - tp: 18706.0000 - fp: 1922.0000 - tn: 18565.0000 - fn: 1767.0000 - accuracy: 0.9099 - precision: 0.9068 - recall: 0.9137 - auc: 0.9682 - prc: 0.9759 - val_loss: 0.2341 - val_cross entropy: 0.2341 - val_Brier score: 0.0543 - val_tp: 75.0000 - val_fp: 1301.0000 - val_tn: 44186.0000 - val_fn: 7.0000 - val_accuracy: 0.9713 - val_precision: 0.0545 - val_recall: 0.9146 - val_auc: 0.9710 - val_prc: 0.7888
Epoch 13/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2223 - cross entropy: 0.2223 - Brier score: 0.0667 - tp: 18874.0000 - fp: 1694.0000 - tn: 18675.0000 - fn: 1717.0000 - accuracy: 0.9167 - precision: 0.9176 - recall: 0.9166 - auc: 0.9720 - prc: 0.9785 - val_loss: 0.2162 - val_cross entropy: 0.2162 - val_Brier score: 0.0488 - val_tp: 75.0000 - val_fp: 1235.0000 - val_tn: 44252.0000 - val_fn: 7.0000 - val_accuracy: 0.9727 - val_precision: 0.0573 - val_recall: 0.9146 - val_auc: 0.9732 - val_prc: 0.7912
Epoch 14/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.2172 - cross entropy: 0.2172 - Brier score: 0.0648 - tp: 18681.0000 - fp: 1627.0000 - tn: 18898.0000 - fn: 1754.0000 - accuracy: 0.9175 - precision: 0.9199 - recall: 0.9142 - auc: 0.9732 - prc: 0.9789 - val_loss: 0.2011 - val_cross entropy: 0.2011 - val_Brier score: 0.0444 - val_tp: 75.0000 - val_fp: 1167.0000 - val_tn: 44320.0000 - val_fn: 7.0000 - val_accuracy: 0.9742 - val_precision: 0.0604 - val_recall: 0.9146 - val_auc: 0.9748 - val_prc: 0.7927
Epoch 15/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2088 - cross entropy: 0.2088 - Brier score: 0.0619 - tp: 18878.0000 - fp: 1484.0000 - tn: 18949.0000 - fn: 1649.0000 - accuracy: 0.9235 - precision: 0.9271 - recall: 0.9197 - auc: 0.9749 - prc: 0.9806 - val_loss: 0.1872 - val_cross entropy: 0.1872 - val_Brier score: 0.0405 - val_tp: 75.0000 - val_fp: 1100.0000 - val_tn: 44387.0000 - val_fn: 7.0000 - val_accuracy: 0.9757 - val_precision: 0.0638 - val_recall: 0.9146 - val_auc: 0.9760 - val_prc: 0.7931
Epoch 16/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.2011 - cross entropy: 0.2011 - Brier score: 0.0596 - tp: 18797.0000 - fp: 1439.0000 - tn: 19068.0000 - fn: 1656.0000 - accuracy: 0.9244 - precision: 0.9289 - recall: 0.9190 - auc: 0.9768 - prc: 0.9818 - val_loss: 0.1743 - val_cross entropy: 0.1743 - val_Brier score: 0.0369 - val_tp: 75.0000 - val_fp: 1029.0000 - val_tn: 44458.0000 - val_fn: 7.0000 - val_accuracy: 0.9773 - val_precision: 0.0679 - val_recall: 0.9146 - val_auc: 0.9769 - val_prc: 0.7935
Epoch 17/1000
20/20 [==============================] - 0s 26ms/step - loss: 0.1961 - cross entropy: 0.1961 - Brier score: 0.0575 - tp: 18762.0000 - fp: 1337.0000 - tn: 19238.0000 - fn: 1623.0000 - accuracy: 0.9277 - precision: 0.9335 - recall: 0.9204 - auc: 0.9773 - prc: 0.9821 - val_loss: 0.1636 - val_cross entropy: 0.1636 - val_Brier score: 0.0342 - val_tp: 75.0000 - val_fp: 997.0000 - val_tn: 44490.0000 - val_fn: 7.0000 - val_accuracy: 0.9780 - val_precision: 0.0700 - val_recall: 0.9146 - val_auc: 0.9777 - val_prc: 0.7943
Epoch 18/1000
20/20 [==============================] - 0s 26ms/step - loss: 0.1891 - cross entropy: 0.1891 - Brier score: 0.0554 - tp: 18751.0000 - fp: 1286.0000 - tn: 19292.0000 - fn: 1631.0000 - accuracy: 0.9288 - precision: 0.9358 - recall: 0.9200 - auc: 0.9789 - prc: 0.9833 - val_loss: 0.1544 - val_cross entropy: 0.1544 - val_Brier score: 0.0320 - val_tp: 75.0000 - val_fp: 981.0000 - val_tn: 44506.0000 - val_fn: 7.0000 - val_accuracy: 0.9783 - val_precision: 0.0710 - val_recall: 0.9146 - val_auc: 0.9780 - val_prc: 0.7971
Epoch 19/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1833 - cross entropy: 0.1833 - Brier score: 0.0534 - tp: 18789.0000 - fp: 1144.0000 - tn: 19432.0000 - fn: 1595.0000 - accuracy: 0.9331 - precision: 0.9426 - recall: 0.9218 - auc: 0.9802 - prc: 0.9842 - val_loss: 0.1461 - val_cross entropy: 0.1461 - val_Brier score: 0.0300 - val_tp: 76.0000 - val_fp: 949.0000 - val_tn: 44538.0000 - val_fn: 6.0000 - val_accuracy: 0.9790 - val_precision: 0.0741 - val_recall: 0.9268 - val_auc: 0.9782 - val_prc: 0.7972
Epoch 20/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1775 - cross entropy: 0.1775 - Brier score: 0.0517 - tp: 18845.0000 - fp: 1120.0000 - tn: 19463.0000 - fn: 1532.0000 - accuracy: 0.9353 - precision: 0.9439 - recall: 0.9248 - auc: 0.9814 - prc: 0.9849 - val_loss: 0.1394 - val_cross entropy: 0.1394 - val_Brier score: 0.0287 - val_tp: 76.0000 - val_fp: 969.0000 - val_tn: 44518.0000 - val_fn: 6.0000 - val_accuracy: 0.9786 - val_precision: 0.0727 - val_recall: 0.9268 - val_auc: 0.9788 - val_prc: 0.7971
Epoch 21/1000
20/20 [==============================] - 1s 26ms/step - loss: 0.1727 - cross entropy: 0.1727 - Brier score: 0.0506 - tp: 19042.0000 - fp: 1056.0000 - tn: 19310.0000 - fn: 1552.0000 - accuracy: 0.9363 - precision: 0.9475 - recall: 0.9246 - auc: 0.9818 - prc: 0.9855 - val_loss: 0.1331 - val_cross entropy: 0.1331 - val_Brier score: 0.0274 - val_tp: 76.0000 - val_fp: 965.0000 - val_tn: 44522.0000 - val_fn: 6.0000 - val_accuracy: 0.9787 - val_precision: 0.0730 - val_recall: 0.9268 - val_auc: 0.9789 - val_prc: 0.7973
Epoch 22/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1711 - cross entropy: 0.1711 - Brier score: 0.0501 - tp: 19041.0000 - fp: 1102.0000 - tn: 19283.0000 - fn: 1534.0000 - accuracy: 0.9356 - precision: 0.9453 - recall: 0.9254 - auc: 0.9826 - prc: 0.9859 - val_loss: 0.1275 - val_cross entropy: 0.1275 - val_Brier score: 0.0262 - val_tp: 76.0000 - val_fp: 965.0000 - val_tn: 44522.0000 - val_fn: 6.0000 - val_accuracy: 0.9787 - val_precision: 0.0730 - val_recall: 0.9268 - val_auc: 0.9784 - val_prc: 0.7879
Epoch 23/1000
20/20 [==============================] - 1s 27ms/step - loss: 0.1657 - cross entropy: 0.1657 - Brier score: 0.0479 - tp: 19074.0000 - fp: 1045.0000 - tn: 19372.0000 - fn: 1469.0000 - accuracy: 0.9386 - precision: 0.9481 - recall: 0.9285 - auc: 0.9838 - prc: 0.9867 - val_loss: 0.1215 - val_cross entropy: 0.1215 - val_Brier score: 0.0249 - val_tp: 76.0000 - val_fp: 939.0000 - val_tn: 44548.0000 - val_fn: 6.0000 - val_accuracy: 0.9793 - val_precision: 0.0749 - val_recall: 0.9268 - val_auc: 0.9785 - val_prc: 0.7882
Epoch 24/1000
20/20 [==============================] - 0s 26ms/step - loss: 0.1631 - cross entropy: 0.1631 - Brier score: 0.0478 - tp: 19006.0000 - fp: 1055.0000 - tn: 19442.0000 - fn: 1457.0000 - accuracy: 0.9387 - precision: 0.9474 - recall: 0.9288 - auc: 0.9839 - prc: 0.9868 - val_loss: 0.1166 - val_cross entropy: 0.1166 - val_Brier score: 0.0239 - val_tp: 76.0000 - val_fp: 924.0000 - val_tn: 44563.0000 - val_fn: 6.0000 - val_accuracy: 0.9796 - val_precision: 0.0760 - val_recall: 0.9268 - val_auc: 0.9780 - val_prc: 0.7886
Epoch 25/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1586 - cross entropy: 0.1586 - Brier score: 0.0464 - tp: 19058.0000 - fp: 971.0000 - tn: 19476.0000 - fn: 1455.0000 - accuracy: 0.9408 - precision: 0.9515 - recall: 0.9291 - auc: 0.9847 - prc: 0.9875 - val_loss: 0.1119 - val_cross entropy: 0.1119 - val_Brier score: 0.0229 - val_tp: 76.0000 - val_fp: 908.0000 - val_tn: 44579.0000 - val_fn: 6.0000 - val_accuracy: 0.9799 - val_precision: 0.0772 - val_recall: 0.9268 - val_auc: 0.9783 - val_prc: 0.7886
Epoch 26/1000
20/20 [==============================] - 0s 26ms/step - loss: 0.1568 - cross entropy: 0.1568 - Brier score: 0.0459 - tp: 18807.0000 - fp: 974.0000 - tn: 19740.0000 - fn: 1439.0000 - accuracy: 0.9411 - precision: 0.9508 - recall: 0.9289 - auc: 0.9851 - prc: 0.9874 - val_loss: 0.1072 - val_cross entropy: 0.1072 - val_Brier score: 0.0219 - val_tp: 76.0000 - val_fp: 881.0000 - val_tn: 44606.0000 - val_fn: 6.0000 - val_accuracy: 0.9805 - val_precision: 0.0794 - val_recall: 0.9268 - val_auc: 0.9779 - val_prc: 0.7889
Epoch 27/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1562 - cross entropy: 0.1562 - Brier score: 0.0457 - tp: 19045.0000 - fp: 1010.0000 - tn: 19477.0000 - fn: 1428.0000 - accuracy: 0.9405 - precision: 0.9496 - recall: 0.9302 - auc: 0.9854 - prc: 0.9876 - val_loss: 0.1032 - val_cross entropy: 0.1032 - val_Brier score: 0.0211 - val_tp: 76.0000 - val_fp: 864.0000 - val_tn: 44623.0000 - val_fn: 6.0000 - val_accuracy: 0.9809 - val_precision: 0.0809 - val_recall: 0.9268 - val_auc: 0.9774 - val_prc: 0.7704
Epoch 28/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1525 - cross entropy: 0.1525 - Brier score: 0.0442 - tp: 19016.0000 - fp: 881.0000 - tn: 19650.0000 - fn: 1413.0000 - accuracy: 0.9440 - precision: 0.9557 - recall: 0.9308 - auc: 0.9862 - prc: 0.9882 - val_loss: 0.0998 - val_cross entropy: 0.0998 - val_Brier score: 0.0205 - val_tp: 76.0000 - val_fp: 866.0000 - val_tn: 44621.0000 - val_fn: 6.0000 - val_accuracy: 0.9809 - val_precision: 0.0807 - val_recall: 0.9268 - val_auc: 0.9778 - val_prc: 0.7706
Epoch 29/1000
20/20 [==============================] - 0s 24ms/step - loss: 0.1465 - cross entropy: 0.1465 - Brier score: 0.0429 - tp: 19105.0000 - fp: 852.0000 - tn: 19596.0000 - fn: 1407.0000 - accuracy: 0.9448 - precision: 0.9573 - recall: 0.9314 - auc: 0.9870 - prc: 0.9891 - val_loss: 0.0968 - val_cross entropy: 0.0968 - val_Brier score: 0.0200 - val_tp: 76.0000 - val_fp: 868.0000 - val_tn: 44619.0000 - val_fn: 6.0000 - val_accuracy: 0.9808 - val_precision: 0.0805 - val_recall: 0.9268 - val_auc: 0.9770 - val_prc: 0.7709
Epoch 30/1000
20/20 [==============================] - 0s 25ms/step - loss: 0.1465 - cross entropy: 0.1465 - Brier score: 0.0431 - tp: 19112.0000 - fp: 860.0000 - tn: 19584.0000 - fn: 1404.0000 - accuracy: 0.9447 - precision: 0.9569 - recall: 0.9316 - auc: 0.9867 - prc: 0.9888 - val_loss: 0.0941 - val_cross entropy: 0.0941 - val_Brier score: 0.0195 - val_tp: 76.0000 - val_fp: 850.0000 - val_tn: 44637.0000 - val_fn: 6.0000 - val_accuracy: 0.9812 - val_precision: 0.0821 - val_recall: 0.9268 - val_auc: 0.9774 - val_prc: 0.7712
Epoch 31/1000
Restoring model weights from the end of the best epoch: 21.
20/20 [==============================] - 0s 25ms/step - loss: 0.1436 - cross entropy: 0.1436 - Brier score: 0.0420 - tp: 19077.0000 - fp: 857.0000 - tn: 19655.0000 - fn: 1371.0000 - accuracy: 0.9456 - precision: 0.9570 - recall: 0.9330 - auc: 0.9876 - prc: 0.9893 - val_loss: 0.0912 - val_cross entropy: 0.0912 - val_Brier score: 0.0189 - val_tp: 76.0000 - val_fp: 826.0000 - val_tn: 44661.0000 - val_fn: 6.0000 - val_accuracy: 0.9817 - val_precision: 0.0843 - val_recall: 0.9268 - val_auc: 0.9767 - val_prc: 0.7622
Epoch 31: early stopping
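
The log stops at epoch 31 and restores the weights from epoch 21. The sketch below replays the bookkeeping that would produce this outcome. It assumes the early-stopping callback monitors `val_prc` (higher is better) with a patience of 10 epochs. Both settings are assumptions inferred from the log, not stated in this section, and `simulate_early_stopping` is a hypothetical helper, not a Keras API.

```python
# val_prc values copied from the training log above (epochs 19 to 31).
# Earlier epochs are omitted because their val_prc is lower than epoch 19's.
val_prc_by_epoch = {
    19: 0.7972, 20: 0.7971, 21: 0.7973, 22: 0.7879, 23: 0.7882,
    24: 0.7886, 25: 0.7886, 26: 0.7889, 27: 0.7704, 28: 0.7706,
    29: 0.7709, 30: 0.7712, 31: 0.7622,
}

PATIENCE = 10  # assumption: inferred from the gap between epochs 21 and 31


def simulate_early_stopping(history, patience):
    """Return (best_epoch, stop_epoch) for a metric where higher is better."""
    best_epoch, best_value, wait = None, float("-inf"), 0
    for epoch in sorted(history):
        value = history[epoch]
        if value > best_value:
            # New best value: remember it and reset the patience counter.
            best_epoch, best_value, wait = epoch, value, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, None  # training ended before patience ran out


best_epoch, stop_epoch = simulate_early_stopping(val_prc_by_epoch, PATIENCE)
print(best_epoch, stop_epoch)  # prints "21 31"
```

Under these assumptions, epoch 21 has the highest `val_prc` (0.7973), and ten epochs pass without improvement before training halts. The validation loss is still falling at epoch 31, so a callback that monitored `val_loss` would not have stopped here.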

Re-check training history

plot_metrics(resampled_history)

Evaluate metrics

train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
90/90 [==============================] - 0s 1ms/step
28/28 [==============================] - 0s 1ms/step
resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
  print(name, ': ', value)
print()
plot_cm(test_labels, test_predictions_resampled)
loss :  0.13269135355949402
cross entropy :  0.13269135355949402
Brier score :  0.02699681930243969
tp :  96.0
fp :  1177.0
tn :  55674.0
fn :  15.0
accuracy :  0.9790737628936768
precision :  0.07541241496801376
recall :  0.8648648858070374
auc :  0.9722627401351929
prc :  0.703483521938324

Legitimate Transactions Detected (True Negatives):  55674
Legitimate Transactions Incorrectly Detected (False Positives):  1177
Fraudulent Transactions Missed (False Negatives):  15
Fraudulent Transactions Detected (True Positives):  96
Total Fraudulent Transactions:  111
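
As a quick sanity check, the reported precision, recall, and accuracy can be recomputed by hand from the four confusion-matrix counts printed above. This is plain arithmetic on those counts, not a call into Keras. The false positive rate is added for comparison and is not one of the metrics in the original output.

```python
# Confusion-matrix counts copied from the evaluation output above.
tp, fp, tn, fn = 96, 1177, 55674, 15

precision = tp / (tp + fp)            # share of flagged transactions that are fraud
recall = tp / (tp + fn)               # share of fraud that was flagged
accuracy = (tp + tn) / (tp + fp + tn + fn)
false_positive_rate = fp / (fp + tn)  # share of legitimate transactions flagged

print(f"precision: {precision:.4f}")  # precision: 0.0754
print(f"recall:    {recall:.4f}")     # recall:    0.8649
print(f"accuracy:  {accuracy:.4f}")   # accuracy:  0.9791
print(f"fpr:       {false_positive_rate:.4f}")
```

Only about 2% of legitimate transactions are flagged, yet precision is only about 7.5%. Legitimate transactions outnumber fraudulent ones so heavily that even a small false positive rate produces far more false alarms than true detections. This is why the precision-recall curve is often more informative than the ROC curve on this dataset.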

Plot the ROC

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')
plot_roc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right');

Plot the AUPRC

plot_prc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_prc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_prc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_prc("Train Resampled", train_labels, train_predictions_resampled, color=colors[2])
plot_prc("Test Resampled", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')
plt.legend(loc='lower right');

Applying this tutorial to your problem

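Imbalanced data classification is an inherently difficult task because there are so few minority-class samples to learn from. Always start with the data: collect as many samples as you can, and think carefully about which features are likely to be relevant so that the model can get the most out of your minority class. At some point your model may stop improving before it delivers the results you want, so keep the context of your problem in mind, along with the trade-offs between different types of errors.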
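
One way to make the trade-off between error types concrete is to assign a cost to each kind of mistake and pick the decision threshold that minimizes the total cost on held-out data. The sketch below is only an illustration. The labels, scores, and cost values are made up, and `total_cost` and `best_threshold` are hypothetical helpers rather than part of this tutorial or of Keras.

```python
def total_cost(labels, scores, threshold, cost_fp, cost_fn):
    """Cost of classifying as positive every example with score >= threshold."""
    fp = sum(1 for y, s in zip(labels, scores) if y == 0 and s >= threshold)
    fn = sum(1 for y, s in zip(labels, scores) if y == 1 and s < threshold)
    return fp * cost_fp + fn * cost_fn


def best_threshold(labels, scores, candidates, cost_fp, cost_fn):
    """Return the candidate with the lowest cost (ties go to the lowest one)."""
    return min(
        candidates,
        key=lambda t: total_cost(labels, scores, t, cost_fp, cost_fn),
    )


# Toy validation set (made up): 7 legitimate (0) and 3 fraudulent (1) examples.
labels = [0, 0, 0, 0, 0, 0, 1, 0, 1, 1]
scores = [0.03, 0.08, 0.12, 0.21, 0.33, 0.52, 0.44, 0.71, 0.83, 0.96]
candidates = [i / 20 for i in range(1, 20)]  # 0.05, 0.10, ..., 0.95

# Hypothetical costs: first both errors cost the same, then a missed fraud
# costs ten times as much as a false alarm.
equal_costs = best_threshold(labels, scores, candidates, cost_fp=1.0, cost_fn=1.0)
costly_misses = best_threshold(labels, scores, candidates, cost_fp=1.0, cost_fn=10.0)

print(equal_costs)    # 0.75
print(costly_misses)  # 0.35
```

When missed fraud is expensive, the chosen threshold drops, accepting more false alarms in exchange for fewer misses. In practice the costs would come from the application, for example the average loss from an undetected fraud versus the cost of reviewing a flagged transaction.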