XGBoost机器学习模型的决策过程
IT行业相对于一般传统行业,发展更新速度更快,一旦停止了学习,很快就会被行业所淘汰。所以我们需要踏踏实实的不断学习,精进自己的技术,尤其是初学者。今天golang学习网给大家整理了《XGBoost机器学习模型的决策过程》,聊聊,我们一起来看看吧!

使用 XGBoost 的算法在 Kaggle 和其它数据科学竞赛中经常可以获得好成绩,因此受到了人们的欢迎。本文用一个具体的数据集分析了 XGBoost 机器学习模型的预测过程,通过使用可视化手段展示结果,我们可以更好地理解模型的预测过程。
随着机器学习的产业应用不断发展,理解、解释和定义机器学习模型的工作原理似乎已成日益明显的趋势。对于非深度学习类型的机器学习分类问题,XGBoost 是最流行的库。由于 XGBoost 可以很好地扩展到大型数据集中,并支持多种语言,它在商业化环境中特别有用。例如,使用 XGBoost 可以很容易地在 Python 中训练模型,并把模型部署到 Java 产品环境中。
虽然 XGBoost 可以达到很高的准确率,但对于 XGBoost 如何进行决策而达到如此高的准确率的过程,还是不够透明。当直接将结果移交给客户的时候,这种不透明可能是很严重的缺陷。理解事情发生的原因是很有用的。那些转向应用机器学习理解数据的公司,同样需要理解来自模型的预测。这一点变得越来越重要。例如,谁也不希望信贷机构使用机器学习模型预测用户的信誉,却无法解释做出这些预测的过程。
另一个例子是,如果我们的机器学习模型说,一个婚姻档案和一个出生档案是和同一个人相关的(档案关联任务),但档案上的日期暗示这桩婚姻的双方分别是一个很老的人和一个很年轻的人,我们可能会质疑为什么模型会将它们关联起来。在诸如这样的例子中,理解模型做出这样的预测的原因是非常有价值的。其结果可能是模型考虑了名字和位置的独特性,并做出了正确的预测。但也可能是模型的特征并没有正确考虑档案上的年龄差距。在这个案例中,对模型预测的理解可以帮助我们寻找提升模型性能的方法。
在这篇文章中,我们将介绍一些技术以更好地理解 XGBoost 的预测过程。这允许我们在利用 gradient boosting 的威力的同时,仍然能理解模型的决策过程。
为了解释这些技术,我们将使用 Titanic 数据集。该数据集有每个泰坦尼克号乘客的信息(包括乘客是否生还)。我们的目标是预测一个乘客是否生还,并且理解做出该预测的过程。即使是使用这些数据,我们也能看到理解模型决策的重要性。想象一下,假如我们有一个关于最近发生的船难的乘客数据集。建立这样的预测模型的目的实际上并不在于预测结果本身,但理解预测过程可以帮助我们学习如何最大化意外中的生还者。
import pandas as pd
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import operator
import matplotlib.pyplot as plt
import seaborn as sns
import lime.lime_tabular
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import Imputer
import numpy as np
from sklearn.grid_search import GridSearchCV
%matplotlib inline
我们要做的首件事是观察我们的数据,你可以在 Kaggle 上找到(https://www.kaggle.com/c/titanic/data)这个数据集。拿到数据集之后,我们会对数据进行简单的清理。即:
- 清除名字和乘客 ID
- 把分类变量转化为虚拟变量
- 用中位数填充和去除数据
这些清洗技巧非常简单,本文的目标不是讨论数据清洗,而是解释 XGBoost,因此这些都是快速、合理的清洗以使模型获得训练。
data = pd.read_csv("./data/titantic/train.csv")
y = data.Survived
X = data.drop(["Survived", "Name", "PassengerId"], 1)
X = pd.get_dummies(X)
现在让我们将数据集分为训练集和测试集。
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42)
并通过少量的超参数测试构建一个训练管道。
pipeline = Pipeline(
[('imputer', Imputer(strategy='median')),
('model', XGBClassifier())])
parameters = dict(model__max_depth=[3, 5, 7],
model__learning_rate=[.01, .1],
model__n_estimators=[100, 500])
cv = GridSearchCV(pipeline, param_grid=parameters)
cv.fit(X_train, y_train)
接着查看测试结果。为简单起见,我们将会使用与 Kaggle 相同的指标:准确率。
test_predictions = cv.predict(X_test)
print("Test Accuracy: {}".format(
accuracy_score(y_test, test_predictions)))
Test Accuracy: 0.8101694915254237
至此我们得到了一个还不错的准确率,在 Kaggle 的大约 9000 个竞争者中排到了前 500 名。因此我们还有进一步提升的空间,但在此将作为留给读者的练习。
我们继续关于理解模型学习到什么的讨论。常用的方法是使用 XGBoost 提供的特征重要性(feature importance)。特征重要性的级别越高,表示该特征对改善模型预测的贡献越大。接下来我们将使用重要性参数对特征进行分级,并比较相对重要性。
fi = list(zip(X.columns, cv.best_estimator_.named_steps['model'].feature_importances_))
fi.sort(key = operator.itemgetter(1), reverse=True)
top_10 = fi[:10]
x = [x[0] for x in top_10]
y = [x[1] for x in top_10]
top_10_chart = sns.barplot(x, y)
plt.setp(top_10_chart.get_xticklabels(), rotation=90)

从上图可以看出,票价和年龄是很重要的特征。我们可以进一步查看生还/遇难与票价的相关分布:
sns.barplot(y_train, X_train['Fare'])

我们可以很清楚地看到,那些生还者相比遇难者的平均票价要高得多,因此把票价当成重要特征可能是合理的。
特征重要性可能是理解一般的特征重要性的不错方法。假如出现了这样的特例,即模型预测一个高票价的乘客无法获得生还,则我们可以得出高票价并不必然导致生还,接下来我们将分析可能导致模型得出该乘客无法生还的其它特征。
这种个体层次上的分析对于生产式机器学习系统可能非常有用。考虑其它例子,使用模型预测是否可以某人一项贷款。我们知道信用评分将是模型的一个很重要的特征,但是却出现了一个拥有高信用评分却被模型拒绝的客户,这时我们将如何向客户做出解释?又该如何向管理者解释?
幸运的是,近期出现了华盛顿大学关于解释任意分类器的预测过程的研究。他们的方法称为 LIME,已经在 GitHub 上开源(https://github.com/marcotcr/lime)。本文不打算对此展开讨论,可以参见论文(https://arxiv.org/pdf/1602.04938.pdf)
接下来我们尝试在模型中应用 LIME。基本上,首先需要定义一个处理训练数据的解释器(我们需要确保传递给解释器的估算训练数据集正是将要训练的数据集):
X_train_imputed = cv.best_estimator_.named_steps['imputer'].transform(X_train)
explainer = lime.lime_tabular.LimeTabularExplainer(X_train_imputed,
feature_names=X_train.columns.tolist(),
class_names=["Not Survived", "Survived"],
discretize_continuous=True)
随后你必须定义一个函数,它以特征数组为变量,并返回一个数组和每个类的概率:
model = cv.best_estimator_.named_steps['model']
def xgb_prediction(X_array_in):
if len(X_array_in.shape) 2:
X_array_in = np.expand_dims(X_array_in, 0)
return model.predict_proba(X_array_in)
最后,我们传递一个示例,让解释器使用你的函数输出特征数和标签:
X_test_imputed = cv.best_estimator_.named_steps['imputer'].transform(X_test)
exp = explainer.explain_instance(
X_test_imputed[1],
xgb_prediction,
num_features=5,
top_labels=1)
exp.show_in_notebook(show_table=True,
show_all=False)

在这里我们有一个示例,76% 的可能性是不存活的。我们还想看看哪个特征对于哪个类贡献最大,重要性又如何。例如,在 Sex = Female 时,生存几率更大。让我们看看柱状图:
sns.barplot(X_train['Sex_female'], y_train)

所以这看起来很有道理。如果你是女性,这就大大提高了你在训练数据中存活的几率。所以为什么预测结果是「未存活」?看起来 Pclass =2.0 大大降低了存活率。让我们看看:
sns.barplot(X_train['Pclass'], y_train)

看起来 Pclass 等于 2 的存活率还是比较低的,所以我们对于自己的预测结果有了更多的理解。看看 LIME 上展示的 top5 特征,看起来这个人似乎仍然能活下来,让我们看看它的标签:
y_test.values[0]>>>1
这个人确实活下来了,所以我们的模型有错!感谢 LIME,我们可以对问题原因有一些认识:看起来 Pclass 可能需要被抛弃。这种方式可以帮助我们,希望能够找到一些改进模型的方法。
本文为读者提供了一个简单有效理解 XGBoost 的方法。希望这些方法可以帮助你合理利用 XGBoost,让你的模型能够做出更好的推断。
今天关于《XGBoost机器学习模型的决策过程》的内容就介绍到这里了,是不是学起来一目了然!想要了解更多关于机器学习,XGBoost的内容请关注golang学习网公众号!
跑ChatGPT体量模型,从此只需一块GPU:加速百倍的方法来了
- 上一篇
- 跑ChatGPT体量模型,从此只需一块GPU:加速百倍的方法来了
- 下一篇
- 大模型如何可靠?IBM等学者最新《基础模型的基础鲁棒性》教程
-
- 魔幻的嚓茶
- 赞 ??,一直没懂这个问题,但其实工作中常常有遇到...不过今天到这,帮助很大,总算是懂了,感谢师傅分享博文!
- 2023-04-25 11:55:45
-
- 舒服的水杯
- 这篇文章出现的刚刚好,好细啊,写的不错,已收藏,关注师傅了!希望师傅能多写科技周边相关的文章。
- 2023-04-24 04:29:03
-
- 懦弱的帽子
- 这篇技术贴出现的刚刚好,作者大大加油!
- 2023-04-18 20:50:56
-
- 英俊的玉米
- 太详细了,码起来,感谢楼主的这篇文章内容,我会继续支持!
- 2023-04-16 06:12:42
-
- 科技周边 · 人工智能 | 4天前 | AI绘画
- AI绘画工具安装与配置教程
- 339浏览 收藏
-
- 科技周边 · 人工智能 | 4天前 |
- 海螺AI语音功能测评与体验分享
- 260浏览 收藏
-
- 科技周边 · 人工智能 | 4天前 |
- ChatGPT读不了加密PDF?先解密再上传
- 438浏览 收藏
-
- 科技周边 · 人工智能 | 4天前 |
- 千问AI测试规范与覆盖率提升技巧
- 152浏览 收藏
-
- 科技周边 · 人工智能 | 4天前 |
- MiniMaxMusic2.0专业模式上线:音乐创作新神器
- 232浏览 收藏
-
- 科技周边 · 人工智能 | 4天前 |
- 即梦AI音乐可视化效果评测
- 280浏览 收藏
-
- 科技周边 · 人工智能 | 4天前 | 豆包AI 豆包AI助手
- 豆包AI写诗技巧与教程分享
- 152浏览 收藏
-
- 科技周边 · 人工智能 | 4天前 | openclaw
- OpenClawAI摘要生成技巧全解析
- 102浏览 收藏
-
- 科技周边 · 人工智能 | 4天前 |
- 百度发布DuMate智能体,李彦宏解读DAA新定义
- 247浏览 收藏
-
- 科技周边 · 人工智能 | 4天前 |
- 智谱清影制作鸟瞰街景镜头教程
- 306浏览 收藏
-
- 科技周边 · 人工智能 | 4天前 | openclaw
- OpenClaw框架解析与技术亮点揭秘
- 357浏览 收藏
-
- 科技周边 · 人工智能 | 4天前 |
- 即梦AI美妆详情页提示词技巧
- 334浏览 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 543次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 516次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 500次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 485次学习
-
- ChatExcel酷表
- ChatExcel酷表是由北京大学团队打造的Excel聊天机器人,用自然语言操控表格,简化数据处理,告别繁琐操作,提升工作效率!适用于学生、上班族及政府人员。
- 6186次使用
-
- Any绘本
- 探索Any绘本(anypicturebook.com/zh),一款开源免费的AI绘本创作工具,基于Google Gemini与Flux AI模型,让您轻松创作个性化绘本。适用于家庭、教育、创作等多种场景,零门槛,高自由度,技术透明,本地可控。
- 6592次使用
-
- 可赞AI
- 可赞AI,AI驱动的办公可视化智能工具,助您轻松实现文本与可视化元素高效转化。无论是智能文档生成、多格式文本解析,还是一键生成专业图表、脑图、知识卡片,可赞AI都能让信息处理更清晰高效。覆盖数据汇报、会议纪要、内容营销等全场景,大幅提升办公效率,降低专业门槛,是您提升工作效率的得力助手。
- 6396次使用
-
- 星月写作
- 星月写作是国内首款聚焦中文网络小说创作的AI辅助工具,解决网文作者从构思到变现的全流程痛点。AI扫榜、专属模板、全链路适配,助力新人快速上手,资深作者效率倍增。
- 8361次使用
-
- MagicLight
- MagicLight.ai是全球首款叙事驱动型AI动画视频创作平台,专注于解决从故事想法到完整动画的全流程痛点。它通过自研AI模型,保障角色、风格、场景高度一致性,让零动画经验者也能高效产出专业级叙事内容。广泛适用于独立创作者、动画工作室、教育机构及企业营销,助您轻松实现创意落地与商业化。
- 7010次使用
-
- GPT-4王者加冕!读图做题性能炸天,凭自己就能考上斯坦福
- 2023-04-25 501浏览
-
- 单块V100训练模型提速72倍!尤洋团队新成果获AAAI 2023杰出论文奖
- 2023-04-24 501浏览
-
- ChatGPT 真的会接管世界吗?
- 2023-04-13 501浏览
-
- VR的终极形态是「假眼」?Neuralink前联合创始人掏出新产品:科学之眼!
- 2023-04-30 501浏览
-
- 实现实时制造可视性优势有哪些?
- 2023-04-15 501浏览

