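Background
This notebook demonstrates how to generate a model card for a Scikit-Learn model with the Model Card Toolkit in a Jupyter/Colab environment. You can learn more about model cards at https://modelcards.withgoogle.com/about.
Setup
First, we need to install and import the necessary packages.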
Upgrade to Pip 21.3 and install packages
pip install --upgrade pip==21.3
pip install -U seaborn scikit-learn model-card-toolkit
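Did you restart the runtime?
If you are using Google Colab, you must restart the runtime the first time you run the cell above (Runtime > Restart runtime...).
Import packages
We import the necessary packages, including Scikit-Learn.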
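After the restart, it can be useful to confirm which package versions the runtime actually picked up. The following is a minimal sketch, assuming Python 3.8 or later (where `importlib.metadata` is part of the standard library); the helper name `installed_versions` is hypothetical and not part of any of the installed packages.

```python
from importlib import metadata  # standard library in Python 3.8+


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Report the packages installed by the cell above.
for name, version in installed_versions(
        ['seaborn', 'scikit-learn', 'model-card-toolkit']).items():
    print(f'{name}: {version or "not installed"}')
```

If a package is reported as not installed, or at an older version than expected, the runtime was probably not restarted after the install cell.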
from datetime import date
from io import BytesIO
from IPython import display
import model_card_toolkit as mctlib
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import plot_roc_curve, plot_confusion_matrix
import base64
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import uuid
Load data
This example uses the Breast Cancer Wisconsin (Diagnostic) dataset, which Scikit-Learn can load with the load_breast_cancer() function.
cancer = load_breast_cancer()
X = pd.DataFrame(cancer.data, columns=cancer.feature_names)
y = pd.Series(cancer.target)
X_train, X_test, y_train, y_test = train_test_split(X, y)
X_train.head()
y_train.head()
17     0
117    0
195    1
337    0
509    0
dtype: int64
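Because train_test_split(X, y) shuffles with a fresh random seed on every run, the rows shown above will differ between runs. The sketch below shows one way to make the split repeatable; the seed value 0 and the use of `stratify` are assumptions added for illustration and are not part of the original notebook.

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

cancer = load_breast_cancer()
X = pd.DataFrame(cancer.data, columns=cancer.feature_names)
y = pd.Series(cancer.target)

# random_state=0 is an arbitrary seed chosen for illustration; stratify=y keeps
# the class ratio roughly equal in the train and test splits.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=0, stratify=y)

# Compare split sizes and the fraction of label-1 rows in each split.
print(len(X_train), len(X_test))
print(y_train.mean(), y_test.mean())
```

A fixed seed also keeps the row counts and plots recorded in the model card consistent with the model that was actually trained.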
Plot data
We will create several plots from the data and include them in the model card.
# Utility function that will export a plot to a base-64 encoded string that the model card will accept.
def plot_to_str():
    img = BytesIO()
    plt.savefig(img, format='png')
    return base64.encodebytes(img.getvalue()).decode('utf-8')
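The helper above saves whichever figure Matplotlib currently considers active. As a possible alternative, the sketch below takes the figure explicitly and closes it afterwards, then checks that the encoded string decodes back to PNG data. The name `figure_to_str` is hypothetical, and the `Agg` backend line is only there so the sketch runs outside a notebook; it can be dropped in Colab.

```python
import base64
from io import BytesIO

import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch also runs outside a notebook
import matplotlib.pyplot as plt


def figure_to_str(fig):
    """Encode a Matplotlib figure as a base-64 PNG string, then close it."""
    img = BytesIO()
    fig.savefig(img, format='png')
    plt.close(fig)  # release the figure so later plots start clean
    return base64.encodebytes(img.getvalue()).decode('utf-8')


fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
encoded = figure_to_str(fig)

# Decoding the string should give back bytes that start with the PNG signature.
png_bytes = base64.b64decode(encoded)
print(png_bytes[:8] == b'\x89PNG\r\n\x1a\n')
```

Seaborn's displot returns a grid object that exposes its underlying figure (as `.figure` in recent releases), so the same idea should carry over to the plots below; treat that attribute name as something to verify against the installed Seaborn version.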
# Plot the mean radius feature for both the train and test sets
sns.displot(x=X_train['mean radius'], hue=y_train)
mean_radius_train = plot_to_str()
sns.displot(x=X_test['mean radius'], hue=y_test)
mean_radius_test = plot_to_str()
# Plot the mean texture feature for both the train and test sets
sns.displot(x=X_train['mean texture'], hue=y_train)
mean_texture_train = plot_to_str()
sns.displot(x=X_test['mean texture'], hue=y_test)
mean_texture_test = plot_to_str()
Train model
# Create a classifier and fit the training data
clf = GradientBoostingClassifier().fit(X_train, y_train)
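Before plotting, a quick numeric check can confirm that the fitted classifier behaves sensibly. This is a minimal, self-contained sketch that retrains on its own split; the seed value 0 and the choice of accuracy and ROC AUC as metrics are assumptions made for illustration.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = GradientBoostingClassifier(random_state=0).fit(X_train, y_train)

accuracy = accuracy_score(y_test, clf.predict(X_test))
# roc_auc_score needs a score for label 1 (benign in this dataset),
# taken here from the second column of predict_proba.
auc = roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1])
print(f'accuracy={accuracy:.3f}, ROC AUC={auc:.3f}')
```

In the notebook itself, the same two metric calls can be applied to the existing clf, X_test, and y_test instead of retraining.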
Evaluate model
# Plot a ROC curve
plot_roc_curve(clf, X_test, y_test)
roc_curve = plot_to_str()
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function plot_roc_curve is deprecated; Function :func:`plot_roc_curve` is deprecated in 1.0 and will be removed in 1.2. Use one of the class methods: :meth:`sklearn.metric.RocCurveDisplay.from_predictions` or :meth:`sklearn.metric.RocCurveDisplay.from_estimator`. warnings.warn(msg, category=FutureWarning)
# Plot a confusion matrix
plot_confusion_matrix(clf, X_test, y_test)
confusion_matrix = plot_to_str()
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function plot_confusion_matrix is deprecated; Function `plot_confusion_matrix` is deprecated in 1.0 and will be removed in 1.2. Use one of the class methods: ConfusionMatrixDisplay.from_predictions or ConfusionMatrixDisplay.from_estimator. warnings.warn(msg, category=FutureWarning)
Create a model card
Initialize toolkit and model card
mct = mctlib.ModelCardToolkit()
model_card = mct.scaffold_assets()
Annotate information into model card
model_card.model_details.name = 'Breast Cancer Wisconsin (Diagnostic) Dataset'
model_card.model_details.overview = (
    'This model predicts whether breast cancer is benign or malignant based on '
    'image measurements.')
model_card.model_details.owners = [
    mctlib.Owner(name='Model Cards Team', contact='[email protected]')
]
model_card.model_details.references = [
    mctlib.Reference(reference='https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)'),
    mctlib.Reference(reference='https://minds.wisconsin.edu/bitstream/handle/1793/59692/TR1131.pdf')
]
model_card.model_details.version.name = str(uuid.uuid4())
model_card.model_details.version.date = str(date.today())
model_card.considerations.ethical_considerations = [mctlib.Risk(
    name=('Manual selection of image sections to digitize could create '
          'selection bias'),
    mitigation_strategy='Automate the selection process'
)]
model_card.considerations.limitations = [mctlib.Limitation(description='Breast cancer diagnosis')]
model_card.considerations.use_cases = [mctlib.UseCase(description='Breast cancer diagnosis')]
model_card.considerations.users = [mctlib.User(description='Medical professionals'), mctlib.User(description='ML researchers')]
model_card.model_parameters.data.append(mctlib.Dataset())
model_card.model_parameters.data[0].graphics.description = (
    f'{len(X_train)} rows with {len(X_train.columns)} features')
model_card.model_parameters.data[0].graphics.collection = [
    mctlib.Graphic(image=mean_radius_train),
    mctlib.Graphic(image=mean_texture_train)
]
model_card.model_parameters.data.append(mctlib.Dataset())
model_card.model_parameters.data[1].graphics.description = (
    f'{len(X_test)} rows with {len(X_test.columns)} features')
model_card.model_parameters.data[1].graphics.collection = [
    mctlib.Graphic(image=mean_radius_test),
    mctlib.Graphic(image=mean_texture_test)
]
model_card.quantitative_analysis.graphics.description = (
    'ROC curve and confusion matrix')
model_card.quantitative_analysis.graphics.collection = [
    mctlib.Graphic(image=roc_curve),
    mctlib.Graphic(image=confusion_matrix)
]
mct.update_model_card(model_card)
Generate model card
# Return the model card document as an HTML page
html = mct.export_format()
display.display(display.HTML(html))
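Since the exported model card is an HTML string, it can also be written to disk for sharing outside the notebook. The sketch below uses a placeholder string in place of the real export so that it runs on its own; the file name `model_card.html` is a hypothetical choice.

```python
from pathlib import Path

# Placeholder standing in for the string returned by mct.export_format().
html = '<html><body><h1>Model Card</h1></body></html>'

# 'model_card.html' is a hypothetical output path chosen for this sketch.
output_path = Path('model_card.html')
output_path.write_text(html, encoding='utf-8')
print(f'Wrote {output_path.stat().st_size} bytes to {output_path}')
```

In the notebook, the placeholder assignment would be dropped so that the `html` variable from the cell above is written instead.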