XGBoost
Extreme gradient boosting: an optimized implementation of the gradient boosting machine that is fast and effective.
- Introduction to xgboost
- Features of xgboost
- Basic usage guide for xgboost
- Theoretical foundations of xgboost
  - supervised learning
  - CART
  - boosting
  - gradient boosting
  - xgboost
- xgboost in practice
  - feature engineering
  - parameter tuning
1 - Getting to know xgboost
- Introduction
- Advantages
- Hands-on practice
1.1 - Introduction
xgboost is an optimized C++ implementation of the gradient boosting machine. Breaking the name "gradient boosting machine" down:
- machine: a machine learning model, which models the patterns underlying the data
- boosting machine: a model that combines weak learners into a strong learner
- gradient boosting machine: combines the weak learners in a gradient-descent-like manner
1.1.1 - machine
The problem of modeling the patterns in data is usually formalized as minimizing an objective function, which typically consists of two parts (see the sketch after this list):
Objective function:
- Loss function (choose the model that fits the training data best)
  - Regression: squared loss $(\hat{y} - y)^2$
  - Classification: 0-1 loss, logistic loss, hinge loss, exponential loss
- Regularization term (choose the simplest model)
  - L2 regularization
  - L1 regularization
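Putting the two parts together: the regularized objective minimized by xgboost is commonly written as below, where $l$ is a differentiable loss, $f_k$ is the $k$-th tree, $T$ the number of leaves in a tree, and $w$ the vector of leaf scores:

$$\mathrm{Obj}(\theta) = \sum_{i=1}^{n} l(y_i, \hat{y}_i) + \sum_{k=1}^{K} \Omega(f_k), \qquad \Omega(f) = \gamma T + \frac{1}{2}\lambda \lVert w \rVert^2$$

The first term measures fit to the training data; the second penalizes tree complexity, which is exactly the loss-versus-simplicity trade-off described above.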
1.1.2 - boosting machine
- boosting: combining weak learners into a strong learner
- weak learner: a decision tree / classification and regression tree (CART, the weak learner xgboost uses)
  - decision tree: each leaf node corresponds to a decision
  - classification and regression tree: each leaf node corresponds to a predicted score
1.1.3 - gradient boosting machine
The classic boosting algorithm is AdaBoost, whose loss function is the exponential loss. Friedman generalized AdaBoost into the general gradient boosting framework, yielding the gradient boosting machine: boosting is treated as a numerical optimization problem and solved in a way analogous to gradient descent. This makes any differentiable loss function usable and widens the supported tasks from binary classification to multi-class classification and beyond.
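To make the "gradient descent on functions" idea concrete, below is a minimal sketch of gradient boosting with squared loss (not xgboost itself; the function names are illustrative). Each round fits a small regression tree to the negative gradient of the loss, which for squared loss is simply the residual; sklearn's DecisionTreeRegressor stands in as the weak learner.

import numpy as np
from sklearn.tree import DecisionTreeRegressor

def gradient_boost_fit(X, y, n_rounds=50, learning_rate=0.1, max_depth=2):
    # Start from a constant prediction, then repeatedly fit a tree
    # to the negative gradient of the squared loss (the residuals).
    init = y.mean()
    f = np.full(len(y), init)
    trees = []
    for _ in range(n_rounds):
        residuals = y - f  # negative gradient of 0.5 * (y - f)^2 w.r.t. f
        tree = DecisionTreeRegressor(max_depth=max_depth).fit(X, residuals)
        f = f + learning_rate * tree.predict(X)  # shrunken additive update
        trees.append(tree)
    return init, trees

def gradient_boost_predict(init, trees, X, learning_rate=0.1):
    # The ensemble prediction is the initial constant plus the
    # shrunken contribution of every fitted tree.
    f = np.full(X.shape[0], init)
    for tree in trees:
        f = f + learning_rate * tree.predict(X)
    return f

The learning_rate shrinkage here plays the same role as the learning_rate parameter used in the xgboost examples later in this section.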
1.2 - Advantages of xgboost
Key features:
- Regularization: standard GBM (gradient boosting machine) has no explicit regularization
- Parallelism
- Custom objectives and evaluation metrics: the loss function only needs to supply its first and second derivatives (see the sketch after this list)
- Pruning: GBM stops splitting as soon as a split yields negative gain, whereas xgboost grows the tree to the maximum depth first and then prunes back
- Support for continued (online) learning
- Strong performance on structured data, whereas deep learning does better on unstructured data.
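To illustrate the custom-objective point above: the native xgboost API accepts a user-defined objective that returns the first derivative (gradient) and second derivative (hessian) of the loss with respect to the current predictions. Below is a minimal sketch modeled on xgboost's custom-objective demo for binary logistic loss; the tiny random dataset exists purely to make the snippet self-contained.

import numpy as np
import xgboost as xgb

def logregobj(preds, dtrain):
    # Binary logistic loss: xgboost hands us the raw margin predictions;
    # we return the gradient and hessian of the loss w.r.t. those margins.
    labels = dtrain.get_label()
    preds = 1.0 / (1.0 + np.exp(-preds))  # sigmoid: margins -> probabilities
    grad = preds - labels                 # first derivative
    hess = preds * (1.0 - preds)          # second derivative
    return grad, hess

# Synthetic data purely for illustration
X = np.random.rand(100, 3)
y = (X[:, 0] > 0.5).astype(float)
dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train({"max_depth": 2, "eta": 0.3}, dtrain, num_boost_round=10, obj=logregobj)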
1.3 - Hands-on with xgboost
1.3.1 - The general workflow of a data science task
1.3.2 - Using xgboost via the sklearn framework
from sklearn.datasets import load_svmlight_file
from sklearn.metrics import accuracy_score
from xgboost import XGBClassifier
Load the data
file_path = "./data/"
X_train, y_train = load_svmlight_file(file_path+"agaricus.txt.train")
X_test, y_test = load_svmlight_file(file_path+"agaricus.txt.test")
print(X_train.shape, y_train.shape)
(6513, 126) (6513,)
print(X_test.shape, y_test.shape)
(1611, 126) (1611,)
Parameter overview:
- max_depth: maximum depth of a tree. Default: 6.
- learning_rate: shrinkage step size used in updates to prevent overfitting. After each boosting step the algorithm obtains the weights of the new features; learning_rate shrinks these weights to make the boosting process more conservative. Default: 0.3, range: [0,1].
- silent: 0 prints runtime messages, 1 runs silently without printing them. Default: 0.
- objective: defines the learning task and the corresponding learning objective; "binary:logistic" means logistic regression for binary classification, with probabilities as output.
Configure the model
xgbc = XGBClassifier(max_depth=2, learning_rate=1,
                     n_estimators=2,  # number of iterations, i.e. number of trees
                     silent=0, objective="binary:logistic")
Train the model
xgbc.fit(X_train, y_train)
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bytree=1, gamma=0, learning_rate=1, max_delta_step=0,
              max_depth=2, min_child_weight=1, missing=None, n_estimators=2,
              n_jobs=1, nthread=None, objective='binary:logistic', random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
              silent=0, subsample=1)
Training accuracy
pred_train = xgbc.predict(X_train)
pred_train = [round(x) for x in pred_train]  # round to integer class labels
train_score = accuracy_score(y_train, pred_train)
print("Train Accuracy: %.2f%%" % (train_score * 100))
Train Accuracy: 97.77%
Test accuracy
pred_test = xgbc.predict(X_test)
pred_test = [1 if x >= 0.5 else 0 for x in pred_test]  # threshold at 0.5, in case predictions are probabilities
print("Test Accuracy: %.2f%%" % (accuracy_score(y_test, pred_test) * 100))
Test Accuracy: 97.83%
1.3.3 - Validation set
Hold out a portion of the training data so that it takes no part in fitting the model parameters; this held-out portion is called the validation set.
- Train the model on the remaining data, then evaluate the trained model on the validation set
- Performance on the validation set can be taken as an estimate of performance on unseen data; choose the model that performs best on the validation set
from sklearn.model_selection import train_test_split
Split into training and validation sets
file_path = "./data/"
X_train, y_train = load_svmlight_file(file_path+"agaricus.txt.train")
X_test, y_test = load_svmlight_file(file_path+"agaricus.txt.test")
X_train, X_validation, y_train, y_validation = train_test_split(X_train, y_train, test_size=0.33, random_state=42)
print(X_train.shape, y_train.shape)
print(X_validation.shape, y_validation.shape)
(4363, 126) (4363,)
(2150, 126) (2150,)
xgbc = XGBClassifier(max_depth=2, learning_rate=1, n_estimators=2, silent=False, objective="binary:logistic")
xgbc.fit(X_train, y_train, verbose=True)
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bytree=1, gamma=0, learning_rate=1, max_delta_step=0,
              max_depth=2, min_child_weight=1, missing=None, n_estimators=2,
              n_jobs=1, nthread=None, objective='binary:logistic', random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
              silent=False, subsample=1)
# performance in validation set
pred_val = xgbc.predict(X_validation)
pred_val = [round(x) for x in pred_val]
print("Validation Accuracy: %2.f%%" % (accuracy_score(pred_val, y_validation) * 100))
Validation Accuracy: 97%
# performance in train set
pred_train = xgbc.predict(X_train)
pred_train = [round(x) for x in pred_train]
print("Validation Accuracy: %2.f%%" % (accuracy_score(pred_train, y_train) * 100))
Validation Accuracy: 98%
# performance in test set
pred_test = xgbc.predict(X_test)
pred_test = [round(x) for x in pred_test]
print("Validation Accuracy: %2.f%%" % (accuracy_score(pred_test, y_test) * 100))
Validation Accuracy: 98%
1.3.4 - Learning curves: varying the number of weak learners (boosting iterations)
import matplotlib.pyplot as plt
n_iteration = 100
xgbc = XGBClassifier(max_depth=2, learning_rate=0.1, n_estimators=n_iteration, objective="binary:logistic")
eval_set = [(X_train, y_train), (X_validation, y_validation)]
xgbc.fit(X_train, y_train, eval_set=eval_set, eval_metric=["error", "logloss"], verbose=True)
[0] validation_0-error:0.044236 validation_0-logloss:0.614162 validation_1-error:0.051163 validation_1-logloss:0.615457
[1] validation_0-error:0.039193 validation_0-logloss:0.549179 validation_1-error:0.046512 validation_1-logloss:0.551203
[2] validation_0-error:0.044236 validation_0-logloss:0.494366 validation_1-error:0.051163 validation_1-logloss:0.497442
[3] validation_0-error:0.039193 validation_0-logloss:0.447845 validation_1-error:0.046512 validation_1-logloss:0.451486
[4] validation_0-error:0.039193 validation_0-logloss:0.407646 validation_1-error:0.046512 validation_1-logloss:0.411989
[5] validation_0-error:0.039193 validation_0-logloss:0.371941 validation_1-error:0.046512 validation_1-logloss:0.377037
[6] validation_0-error:0.022003 validation_0-logloss:0.341067 validation_1-error:0.026047 validation_1-logloss:0.346286
[7] validation_0-error:0.039193 validation_0-logloss:0.313232 validation_1-error:0.046512 validation_1-logloss:0.319077
[8] validation_0-error:0.039193 validation_0-logloss:0.288775 validation_1-error:0.046512 validation_1-logloss:0.294526
[9] validation_0-error:0.022003 validation_0-logloss:0.267046 validation_1-error:0.026047 validation_1-logloss:0.273228
[10] validation_0-error:0.004813 validation_0-logloss:0.247238 validation_1-error:0.008372 validation_1-logloss:0.253542
[11] validation_0-error:0.004813 validation_0-logloss:0.229689 validation_1-error:0.008372 validation_1-logloss:0.236248
[12] validation_0-error:0.010085 validation_0-logloss:0.210475 validation_1-error:0.015349 validation_1-logloss:0.216868
[13] validation_0-error:0.015586 validation_0-logloss:0.193727 validation_1-error:0.02093 validation_1-logloss:0.199968
[14] validation_0-error:0.015586 validation_0-logloss:0.179108 validation_1-error:0.02093 validation_1-logloss:0.185209
[15] validation_0-error:0.015586 validation_0-logloss:0.166333 validation_1-error:0.02093 validation_1-logloss:0.172308
[16] validation_0-error:0.015586 validation_0-logloss:0.15516 validation_1-error:0.02093 validation_1-logloss:0.16102
[17] validation_0-error:0.015586 validation_0-logloss:0.145382 validation_1-error:0.02093 validation_1-logloss:0.151137
[18] validation_0-error:0.015586 validation_0-logloss:0.13682 validation_1-error:0.02093 validation_1-logloss:0.142481
[19] validation_0-error:0.015586 validation_0-logloss:0.129854 validation_1-error:0.02093 validation_1-logloss:0.135452
[20] validation_0-error:0.015586 validation_0-logloss:0.122889 validation_1-error:0.02093 validation_1-logloss:0.128415
[21] validation_0-error:0.023608 validation_0-logloss:0.11724 validation_1-error:0.029302 validation_1-logloss:0.122718
[22] validation_0-error:0.023608 validation_0-logloss:0.111548 validation_1-error:0.029302 validation_1-logloss:0.116973
[23] validation_0-error:0.02017 validation_0-logloss:0.106935 validation_1-error:0.024186 validation_1-logloss:0.112492
[24] validation_0-error:0.02017 validation_0-logloss:0.102711 validation_1-error:0.024186 validation_1-logloss:0.108251
[25] validation_0-error:0.02017 validation_0-logloss:0.098366 validation_1-error:0.024186 validation_1-logloss:0.103854
[26] validation_0-error:0.02017 validation_0-logloss:0.094848 validation_1-error:0.024186 validation_1-logloss:0.100122
[27] validation_0-error:0.02017 validation_0-logloss:0.09125 validation_1-error:0.024186 validation_1-logloss:0.096787
[28] validation_0-error:0.02017 validation_0-logloss:0.087968 validation_1-error:0.024186 validation_1-logloss:0.093459
[29] validation_0-error:0.02017 validation_0-logloss:0.084816 validation_1-error:0.024186 validation_1-logloss:0.090229
[30] validation_0-error:0.02017 validation_0-logloss:0.081983 validation_1-error:0.024186 validation_1-logloss:0.087354
[31] validation_0-error:0.02017 validation_0-logloss:0.079313 validation_1-error:0.024186 validation_1-logloss:0.084619
[32] validation_0-error:0.012148 validation_0-logloss:0.074708 validation_1-error:0.015814 validation_1-logloss:0.080086
[33] validation_0-error:0.012148 validation_0-logloss:0.071661 validation_1-error:0.015814 validation_1-logloss:0.077247
[34] validation_0-error:0.02017 validation_0-logloss:0.069014 validation_1-error:0.024186 validation_1-logloss:0.074588
[35] validation_0-error:0.014669 validation_0-logloss:0.06648 validation_1-error:0.018605 validation_1-logloss:0.072239
[36] validation_0-error:0.009397 validation_0-logloss:0.064195 validation_1-error:0.011628 validation_1-logloss:0.069621
[37] validation_0-error:0.001375 validation_0-logloss:0.062203 validation_1-error:0.003256 validation_1-logloss:0.06757
[38] validation_0-error:0.001375 validation_0-logloss:0.060052 validation_1-error:0.003256 validation_1-logloss:0.065462
[39] validation_0-error:0.001375 validation_0-logloss:0.05799 validation_1-error:0.003256 validation_1-logloss:0.063569
[40] validation_0-error:0.001375 validation_0-logloss:0.056169 validation_1-error:0.003256 validation_1-logloss:0.061491
[41] validation_0-error:0.001375 validation_0-logloss:0.054376 validation_1-error:0.003256 validation_1-logloss:0.059743
[42] validation_0-error:0.009397 validation_0-logloss:0.052657 validation_1-error:0.011628 validation_1-logloss:0.058177
[43] validation_0-error:0.001375 validation_0-logloss:0.051002 validation_1-error:0.003256 validation_1-logloss:0.056733
[44] validation_0-error:0.001375 validation_0-logloss:0.049429 validation_1-error:0.003256 validation_1-logloss:0.054922
[45] validation_0-error:0.001375 validation_0-logloss:0.047924 validation_1-error:0.003256 validation_1-logloss:0.053362
[46] validation_0-error:0.001375 validation_0-logloss:0.046491 validation_1-error:0.003256 validation_1-logloss:0.051973
[47] validation_0-error:0.001375 validation_0-logloss:0.045115 validation_1-error:0.003256 validation_1-logloss:0.050731
[48] validation_0-error:0.001375 validation_0-logloss:0.04384 validation_1-error:0.003256 validation_1-logloss:0.049218
[49] validation_0-error:0.001375 validation_0-logloss:0.04261 validation_1-error:0.003256 validation_1-logloss:0.048026
[50] validation_0-error:0.001375 validation_0-logloss:0.041414 validation_1-error:0.003256 validation_1-logloss:0.046635
[51] validation_0-error:0.001375 validation_0-logloss:0.04024 validation_1-error:0.003256 validation_1-logloss:0.04559
[52] validation_0-error:0.001375 validation_0-logloss:0.039108 validation_1-error:0.003256 validation_1-logloss:0.044651
[53] validation_0-error:0.001375 validation_0-logloss:0.038046 validation_1-error:0.003256 validation_1-logloss:0.043404
[54] validation_0-error:0.001375 validation_0-logloss:0.036975 validation_1-error:0.003256 validation_1-logloss:0.042286
[55] validation_0-error:0.001375 validation_0-logloss:0.035982 validation_1-error:0.003256 validation_1-logloss:0.041341
[56] validation_0-error:0.001375 validation_0-logloss:0.035031 validation_1-error:0.003256 validation_1-logloss:0.040505
[57] validation_0-error:0.001375 validation_0-logloss:0.034135 validation_1-error:0.003256 validation_1-logloss:0.039399
[58] validation_0-error:0.001375 validation_0-logloss:0.033276 validation_1-error:0.003256 validation_1-logloss:0.038583
[59] validation_0-error:0.001375 validation_0-logloss:0.032452 validation_1-error:0.003256 validation_1-logloss:0.037861
[60] validation_0-error:0.001375 validation_0-logloss:0.031655 validation_1-error:0.003256 validation_1-logloss:0.036928
[61] validation_0-error:0.001375 validation_0-logloss:0.030869 validation_1-error:0.003256 validation_1-logloss:0.035987
[62] validation_0-error:0.001375 validation_0-logloss:0.030057 validation_1-error:0.003256 validation_1-logloss:0.035138
[63] validation_0-error:0.001375 validation_0-logloss:0.029379 validation_1-error:0.003256 validation_1-logloss:0.034418
[64] validation_0-error:0.001375 validation_0-logloss:0.028683 validation_1-error:0.003256 validation_1-logloss:0.033762
[65] validation_0-error:0.001375 validation_0-logloss:0.028014 validation_1-error:0.003256 validation_1-logloss:0.033187
[66] validation_0-error:0.001375 validation_0-logloss:0.027338 validation_1-error:0.003256 validation_1-logloss:0.032326
[67] validation_0-error:0.001375 validation_0-logloss:0.026727 validation_1-error:0.003256 validation_1-logloss:0.031581
[68] validation_0-error:0.001375 validation_0-logloss:0.026087 validation_1-error:0.003256 validation_1-logloss:0.031107
[69] validation_0-error:0.001375 validation_0-logloss:0.025474 validation_1-error:0.003256 validation_1-logloss:0.030427
[70] validation_0-error:0.001375 validation_0-logloss:0.024911 validation_1-error:0.003256 validation_1-logloss:0.029905
[71] validation_0-error:0.001375 validation_0-logloss:0.024368 validation_1-error:0.003256 validation_1-logloss:0.029239
[72] validation_0-error:0.001375 validation_0-logloss:0.023829 validation_1-error:0.003256 validation_1-logloss:0.028852
[73] validation_0-error:0.001375 validation_0-logloss:0.023316 validation_1-error:0.003256 validation_1-logloss:0.028419
[74] validation_0-error:0.001375 validation_0-logloss:0.02278 validation_1-error:0.003256 validation_1-logloss:0.027854
[75] validation_0-error:0.001375 validation_0-logloss:0.022305 validation_1-error:0.003256 validation_1-logloss:0.027263
[76] validation_0-error:0.001375 validation_0-logloss:0.021837 validation_1-error:0.003256 validation_1-logloss:0.026841
[77] validation_0-error:0.001375 validation_0-logloss:0.02139 validation_1-error:0.003256 validation_1-logloss:0.02647
[78] validation_0-error:0.001375 validation_0-logloss:0.020914 validation_1-error:0.003256 validation_1-logloss:0.02589
[79] validation_0-error:0.001375 validation_0-logloss:0.020452 validation_1-error:0.003256 validation_1-logloss:0.025369
[80] validation_0-error:0.001375 validation_0-logloss:0.020058 validation_1-error:0.003256 validation_1-logloss:0.024872
[81] validation_0-error:0.001375 validation_0-logloss:0.019648 validation_1-error:0.003256 validation_1-logloss:0.024367
[82] validation_0-error:0.001375 validation_0-logloss:0.019268 validation_1-error:0.003256 validation_1-logloss:0.023936
[83] validation_0-error:0.001375 validation_0-logloss:0.018878 validation_1-error:0.003256 validation_1-logloss:0.023496
[84] validation_0-error:0.001375 validation_0-logloss:0.018503 validation_1-error:0.003256 validation_1-logloss:0.023169
[85] validation_0-error:0.001375 validation_0-logloss:0.018148 validation_1-error:0.003256 validation_1-logloss:0.022877
[86] validation_0-error:0.001375 validation_0-logloss:0.017783 validation_1-error:0.003256 validation_1-logloss:0.022427
[87] validation_0-error:0.001375 validation_0-logloss:0.01746 validation_1-error:0.003256 validation_1-logloss:0.022145
[88] validation_0-error:0.001375 validation_0-logloss:0.017149 validation_1-error:0.003256 validation_1-logloss:0.021805
[89] validation_0-error:0.001375 validation_0-logloss:0.016832 validation_1-error:0.003256 validation_1-logloss:0.021546
[90] validation_0-error:0.001375 validation_0-logloss:0.016305 validation_1-error:0.003256 validation_1-logloss:0.020802
[91] validation_0-error:0.001375 validation_0-logloss:0.016013 validation_1-error:0.003256 validation_1-logloss:0.020549
[92] validation_0-error:0.001375 validation_0-logloss:0.015729 validation_1-error:0.003256 validation_1-logloss:0.02018
[93] validation_0-error:0.001375 validation_0-logloss:0.015467 validation_1-error:0.003256 validation_1-logloss:0.019926
[94] validation_0-error:0.001375 validation_0-logloss:0.015202 validation_1-error:0.003256 validation_1-logloss:0.019611
[95] validation_0-error:0.001375 validation_0-logloss:0.014931 validation_1-error:0.003256 validation_1-logloss:0.019267
[96] validation_0-error:0.001375 validation_0-logloss:0.014652 validation_1-error:0.003256 validation_1-logloss:0.018949
[97] validation_0-error:0.001375 validation_0-logloss:0.014399 validation_1-error:0.003256 validation_1-logloss:0.018651
[98] validation_0-error:0.001375 validation_0-logloss:0.014151 validation_1-error:0.003256 validation_1-logloss:0.018445
[99] validation_0-error:0.001375 validation_0-logloss:0.013908 validation_1-error:0.003256 validation_1-logloss:0.018252
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bytree=1, gamma=0, learning_rate=0.1, max_delta_step=0,
              max_depth=2, min_child_weight=1, missing=None, n_estimators=100,
              n_jobs=1, nthread=None, objective='binary:logistic', random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
              silent=True, subsample=1)
plt.rcParams["figure.figsize"] = (5., 3.)
result = xgbc.evals_result()
epochs = len(result["validation_0"]["error"])
fig, ax = plt.subplots()
ax.plot(list(range(epochs)), result["validation_0"]["error"], label="train")
ax.plot(list(range(epochs)), result["validation_1"]["error"], label="validation")
ax.legend()
plt.ylabel("error")
plt.xlabel("epoch")
plt.title("XGBoost error")
plt.show()
fig, ax = plt.subplots()
ax.plot(list(range(epochs)), result["validation_0"]["logloss"], label="train")
ax.plot(list(range(epochs)), result["validation_1"]["logloss"], label="validation")
ax.legend()
plt.ylabel("logloss")
plt.xlabel("epoch")
plt.title("XGBoost logloss")
plt.show()
# performance in the test set
pred_test = xgbc.predict(X_test)
pred_test = [round(x) for x in pred_test]
print("Test Accuracy: %.2f%%" % (accuracy_score(y_test, pred_test) * 100))
Test Accuracy: 99.81%
1.3.5 - Early stopping
A technique for preventing overfitting:
- Monitor the model's performance on the validation set; if performance has not improved after a fixed number of iterations, stop training.
eval_set = [(X_validation, y_validation)]
xgbc.fit(X_train, y_train, eval_set=eval_set, eval_metric="error", early_stopping_rounds=10, verbose=True)
[0] validation_0-error:0.051163
Will train until validation_0-error hasn't improved in 10 rounds.
[1] validation_0-error:0.046512
[2] validation_0-error:0.051163
[3] validation_0-error:0.046512
[4] validation_0-error:0.046512
[5] validation_0-error:0.046512
[6] validation_0-error:0.026047
[7] validation_0-error:0.046512
[8] validation_0-error:0.046512
[9] validation_0-error:0.026047
[10] validation_0-error:0.008372
[11] validation_0-error:0.008372
[12] validation_0-error:0.015349
[13] validation_0-error:0.02093
[14] validation_0-error:0.02093
[15] validation_0-error:0.02093
[16] validation_0-error:0.02093
[17] validation_0-error:0.02093
[18] validation_0-error:0.02093
[19] validation_0-error:0.02093
[20] validation_0-error:0.02093
Stopping. Best iteration:
[10] validation_0-error:0.008372
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bytree=1, gamma=0, learning_rate=0.1, max_delta_step=0,
              max_depth=2, min_child_weight=1, missing=None, n_estimators=100,
              n_jobs=1, nthread=None, objective='binary:logistic', random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
              silent=True, subsample=1)
result = xgbc.evals_result()
plt.plot(list(range(len(result["validation_0"]["error"]))), result["validation_0"]["error"])
plt.ylabel("error")
plt.title("XGBoost error-early stop")
plt.show()
pred_test = xgbc.predict(X_test)
pred_test = [1 if x >= 0.5 else 0 for x in pred_test]
print("Train Accuracy: %.4f" % (accuracy_score(pred_test, y_test)))
Train Accuracy: 0.9808
1.3.6 - Cross validation
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn import preprocessing
import warnings
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
kfolds = StratifiedKFold(n_splits=10, random_state=42)  # random_state only takes effect when shuffle=True
print(kfolds)
results = cross_val_score(xgbc, X_train, y_train, cv=kfolds)
StratifiedKFold(n_splits=10, random_state=42, shuffle=False)
print(results)
print("%.2f%%, %.2f%%" % (results.mean() * 100, results.std() *100))
[0.99771167 0.99771167 1.         1.         0.99542334 0.99770642
 0.99770642 1.         1.         1.        ]
99.86%, 0.15%
1.3.7 - GridSearchCV
# sklearn.grid_search is the old (pre-0.20) module; in current scikit-learn,
# import GridSearchCV from sklearn.model_selection instead
from sklearn.grid_search import GridSearchCV
xgbc = XGBClassifier(max_depth=2, objective="binary:logistic")
param_search = {
    "n_estimators": list(range(1, 10, 1)),
    "learning_rate": [x/10 for x in range(1, 11, 1)],
}
clf = GridSearchCV(estimator=xgbc, param_grid=param_search, cv=5)
clf.fit(X_train, y_train)
GridSearchCV(cv=5, error_score='raise',
             estimator=XGBClassifier(base_score=0.5, booster='gbtree',
                                     colsample_bylevel=1, colsample_bytree=1, gamma=0,
                                     learning_rate=0.1, max_delta_step=0, max_depth=2,
                                     min_child_weight=1, missing=None, n_estimators=100,
                                     n_jobs=1, nthread=None, objective='binary:logistic',
                                     random_state=0, reg_alpha=0, reg_lambda=1,
                                     scale_pos_weight=1, seed=None, silent=True, subsample=1),
             fit_params={}, iid=True, n_jobs=1,
             param_grid={'learning_rate': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
                         'n_estimators': [1, 2, 3, 4, 5, 6, 7, 8, 9]},
             pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)
clf.grid_scores_  # removed in newer scikit-learn; use clf.cv_results_ instead
[mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.1, 'n_estimators': 1},
 mean: 0.95966, std: 0.01175, params: {'learning_rate': 0.1, 'n_estimators': 2},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.1, 'n_estimators': 3},
 mean: 0.95874, std: 0.01161, params: {'learning_rate': 0.1, 'n_estimators': 4},
 mean: 0.95966, std: 0.01175, params: {'learning_rate': 0.1, 'n_estimators': 5},
 mean: 0.95966, std: 0.01175, params: {'learning_rate': 0.1, 'n_estimators': 6},
 mean: 0.96997, std: 0.01554, params: {'learning_rate': 0.1, 'n_estimators': 7},
 mean: 0.96379, std: 0.01191, params: {'learning_rate': 0.1, 'n_estimators': 8},
 mean: 0.96402, std: 0.01220, params: {'learning_rate': 0.1, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.2, 'n_estimators': 1},
 mean: 0.95920, std: 0.00824, params: {'learning_rate': 0.2, 'n_estimators': 2},
 mean: 0.97181, std: 0.01766, params: {'learning_rate': 0.2, 'n_estimators': 3},
 mean: 0.95966, std: 0.01175, params: {'learning_rate': 0.2, 'n_estimators': 4},
 mean: 0.97570, std: 0.01759, params: {'learning_rate': 0.2, 'n_estimators': 5},
 mean: 0.97937, std: 0.01934, params: {'learning_rate': 0.2, 'n_estimators': 6},
 mean: 0.98212, std: 0.00940, params: {'learning_rate': 0.2, 'n_estimators': 7},
 mean: 0.98441, std: 0.00484, params: {'learning_rate': 0.2, 'n_estimators': 8},
 mean: 0.97914, std: 0.00751, params: {'learning_rate': 0.2, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.3, 'n_estimators': 1},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.3, 'n_estimators': 2},
 mean: 0.97800, std: 0.00585, params: {'learning_rate': 0.3, 'n_estimators': 3},
 mean: 0.96081, std: 0.00956, params: {'learning_rate': 0.3, 'n_estimators': 4},
 mean: 0.98441, std: 0.00320, params: {'learning_rate': 0.3, 'n_estimators': 5},
 mean: 0.97937, std: 0.00397, params: {'learning_rate': 0.3, 'n_estimators': 6},
 mean: 0.98556, std: 0.00426, params: {'learning_rate': 0.3, 'n_estimators': 7},
 mean: 0.97823, std: 0.00579, params: {'learning_rate': 0.3, 'n_estimators': 8},
 mean: 0.97983, std: 0.00604, params: {'learning_rate': 0.3, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.4, 'n_estimators': 1},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.4, 'n_estimators': 2},
 mean: 0.97800, std: 0.00585, params: {'learning_rate': 0.4, 'n_estimators': 3},
 mean: 0.96768, std: 0.00711, params: {'learning_rate': 0.4, 'n_estimators': 4},
 mean: 0.97548, std: 0.00642, params: {'learning_rate': 0.4, 'n_estimators': 5},
 mean: 0.97479, std: 0.00660, params: {'learning_rate': 0.4, 'n_estimators': 6},
 mean: 0.98419, std: 0.00493, params: {'learning_rate': 0.4, 'n_estimators': 7},
 mean: 0.99083, std: 0.00576, params: {'learning_rate': 0.4, 'n_estimators': 8},
 mean: 0.99335, std: 0.00255, params: {'learning_rate': 0.4, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.5, 'n_estimators': 1},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.5, 'n_estimators': 2},
 mean: 0.97593, std: 0.00416, params: {'learning_rate': 0.5, 'n_estimators': 3},
 mean: 0.97112, std: 0.00926, params: {'learning_rate': 0.5, 'n_estimators': 4},
 mean: 0.98694, std: 0.00395, params: {'learning_rate': 0.5, 'n_estimators': 5},
 mean: 0.98143, std: 0.00603, params: {'learning_rate': 0.5, 'n_estimators': 6},
 mean: 0.99198, std: 0.00507, params: {'learning_rate': 0.5, 'n_estimators': 7},
 mean: 0.99404, std: 0.00443, params: {'learning_rate': 0.5, 'n_estimators': 8},
 mean: 0.99862, std: 0.00134, params: {'learning_rate': 0.5, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.6, 'n_estimators': 1},
 mean: 0.95554, std: 0.00936, params: {'learning_rate': 0.6, 'n_estimators': 2},
 mean: 0.97548, std: 0.00642, params: {'learning_rate': 0.6, 'n_estimators': 3},
 mean: 0.97410, std: 0.00751, params: {'learning_rate': 0.6, 'n_estimators': 4},
 mean: 0.98304, std: 0.00653, params: {'learning_rate': 0.6, 'n_estimators': 5},
 mean: 0.99244, std: 0.00792, params: {'learning_rate': 0.6, 'n_estimators': 6},
 mean: 0.99771, std: 0.00218, params: {'learning_rate': 0.6, 'n_estimators': 7},
 mean: 0.99794, std: 0.00222, params: {'learning_rate': 0.6, 'n_estimators': 8},
 mean: 0.99862, std: 0.00134, params: {'learning_rate': 0.6, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.7, 'n_estimators': 1},
 mean: 0.97387, std: 0.01725, params: {'learning_rate': 0.7, 'n_estimators': 2},
 mean: 0.97823, std: 0.00610, params: {'learning_rate': 0.7, 'n_estimators': 3},
 mean: 0.97983, std: 0.00726, params: {'learning_rate': 0.7, 'n_estimators': 4},
 mean: 0.99060, std: 0.00275, params: {'learning_rate': 0.7, 'n_estimators': 5},
 mean: 0.99427, std: 0.00162, params: {'learning_rate': 0.7, 'n_estimators': 6},
 mean: 0.99679, std: 0.00183, params: {'learning_rate': 0.7, 'n_estimators': 7},
 mean: 0.99702, std: 0.00186, params: {'learning_rate': 0.7, 'n_estimators': 8},
 mean: 0.99702, std: 0.00186, params: {'learning_rate': 0.7, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.8, 'n_estimators': 1},
 mean: 0.97227, std: 0.00856, params: {'learning_rate': 0.8, 'n_estimators': 2},
 mean: 0.98075, std: 0.00708, params: {'learning_rate': 0.8, 'n_estimators': 3},
 mean: 0.98579, std: 0.00674, params: {'learning_rate': 0.8, 'n_estimators': 4},
 mean: 0.99404, std: 0.00577, params: {'learning_rate': 0.8, 'n_estimators': 5},
 mean: 0.99794, std: 0.00255, params: {'learning_rate': 0.8, 'n_estimators': 6},
 mean: 0.99885, std: 0.00102, params: {'learning_rate': 0.8, 'n_estimators': 7},
 mean: 0.99908, std: 0.00086, params: {'learning_rate': 0.8, 'n_estimators': 8},
 mean: 0.99862, std: 0.00134, params: {'learning_rate': 0.8, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 0.9, 'n_estimators': 1},
 mean: 0.97937, std: 0.00397, params: {'learning_rate': 0.9, 'n_estimators': 2},
 mean: 0.98900, std: 0.00809, params: {'learning_rate': 0.9, 'n_estimators': 3},
 mean: 0.98487, std: 0.00689, params: {'learning_rate': 0.9, 'n_estimators': 4},
 mean: 0.99496, std: 0.00438, params: {'learning_rate': 0.9, 'n_estimators': 5},
 mean: 0.99565, std: 0.00302, params: {'learning_rate': 0.9, 'n_estimators': 6},
 mean: 0.99931, std: 0.00056, params: {'learning_rate': 0.9, 'n_estimators': 7},
 mean: 0.99817, std: 0.00200, params: {'learning_rate': 0.9, 'n_estimators': 8},
 mean: 0.99817, std: 0.00200, params: {'learning_rate': 0.9, 'n_estimators': 9},
 mean: 0.95576, std: 0.00954, params: {'learning_rate': 1.0, 'n_estimators': 1},
 mean: 0.97937, std: 0.00397, params: {'learning_rate': 1.0, 'n_estimators': 2},
 mean: 0.98969, std: 0.00763, params: {'learning_rate': 1.0, 'n_estimators': 3},
 mean: 0.98648, std: 0.00519, params: {'learning_rate': 1.0, 'n_estimators': 4},
 mean: 0.99450, std: 0.00197, params: {'learning_rate': 1.0, 'n_estimators': 5},
 mean: 0.99633, std: 0.00152, params: {'learning_rate': 1.0, 'n_estimators': 6},
 mean: 0.99817, std: 0.00200, params: {'learning_rate': 1.0, 'n_estimators': 7},
 mean: 0.99931, std: 0.00056, params: {'learning_rate': 1.0, 'n_estimators': 8},
 mean: 0.99931, std: 0.00056, params: {'learning_rate': 1.0, 'n_estimators': 9}]
clf.best_score_
0.9993123997249599
clf.best_params_
{'learning_rate': 0.9, 'n_estimators': 7}
pred_val = clf.predict(X_validation)
print("Validation Accuracy: %.2f%%" % (accuracy_score(y_validation, [round(x) for x in pred_val]) * 100))
Validation Accuracy: 100.00%
pred_test = clf.predict(X_test)
print("Test Accuracy: %.2f%%" % (accuracy_score(y_test, [round(x) for x in pred_test]) * 100))