CatBoost vs. LightGBM vs. XGBoost:誰是提升機(jī)中的王者?
提升算法是一類機(jī)器學(xué)習(xí)算法,通過迭代地訓(xùn)練一系列弱分類器(通常是決策樹)來構(gòu)建一個(gè)強(qiáng)分類器。在每一輪迭代中,新的分類器被設(shè)計(jì)為修正前一輪分類器的錯(cuò)誤,從而逐步提高整體的分類性能。
盡管神經(jīng)網(wǎng)絡(luò)興起并流行起來,但提升算法仍然相當(dāng)實(shí)用。因?yàn)樗鼈冊(cè)谟?xùn)練數(shù)據(jù)有限、訓(xùn)練時(shí)間短、缺乏參數(shù)調(diào)優(yōu)專業(yè)知識(shí)等的情況下,仍然有良好的表現(xiàn)。
提升算法有AdaBoost、CatBoost、LightGBM、XGBoost等。
本文,將重點(diǎn)關(guān)注CatBoost、LightGBM、XGBoost。將包括:
- 結(jié)構(gòu)上的區(qū)別;
- 每個(gè)算法對(duì)分類變量的處理方式;
- 理解參數(shù);
- 在數(shù)據(jù)集上的實(shí)踐;
- 每個(gè)算法的性能。
文章來自:https://towardsdatascience.com/catboost-vs-light-gbm-vs-xgboost-5f93620723db
為適合中文閱讀習(xí)慣,閱讀更有代入感,原文翻譯后有刪改。
Hunter Phillips | 作者
羅伯特 | 編輯
由于 XGBoost(通常被稱為 GBM Killer)在機(jī)器學(xué)習(xí)領(lǐng)域已經(jīng)存在了很長時(shí)間,并且有很多文章專門介紹它,因此本文將更多地關(guān)注 CatBoost 和 LGBM。
LightGBM使用一種新穎的梯度單邊采樣(Gradient-based One-Side Sampling,GOSS)技術(shù),在查找分裂值時(shí)過濾數(shù)據(jù)實(shí)例,而XGBoost使用預(yù)排序算法(pre-sorted algorithm)和基于直方圖的算法(Histogram-based algorithm)來計(jì)算最佳分裂。
上面的實(shí)例指的是觀測(cè)/樣本。
首先,讓我們了解一下XGBoost的預(yù)排序分裂是如何工作的:
- 對(duì)于每個(gè)節(jié)點(diǎn),枚舉所有特征;
- 對(duì)于每個(gè)特征,按特征值對(duì)實(shí)例進(jìn)行排序;
- 使用線性掃描來根據(jù)信息增益(information gain)決定該特征上的最佳分裂;
選擇所有特征中的最佳分裂解決方案。
簡單來說,基于直方圖的算法將特征的所有數(shù)據(jù)點(diǎn)分成離散的箱子,并使用這些箱子來找到直方圖的分裂值。雖然在訓(xùn)練速度上比預(yù)排序算法高效,后者需要枚舉預(yù)排序的特征值上的所有可能分裂點(diǎn),但在速度方面仍然落后于GOSS。
那么,是什么使得GOSS方法高效呢?
在AdaBoost中,樣本權(quán)重可以作為樣本重要性的良好指標(biāo)。然而,在梯度提升決策樹(GBDT)中,沒有原生的樣本權(quán)重,因此無法直接應(yīng)用于AdaBoost提出的采樣方法。這就引入了基于梯度的采樣方法。
梯度代表損失函數(shù)切線的斜率,因此在某種意義上,如果數(shù)據(jù)點(diǎn)的梯度較大,這些點(diǎn)對(duì)于找到最佳分裂點(diǎn)是重要的,因?yàn)樗鼈兙哂懈叩恼`差。
GOSS保留所有具有較大梯度的實(shí)例,并對(duì)具有較小梯度的實(shí)例進(jìn)行隨機(jī)采樣。例如,假設(shè)我有50萬行的數(shù)據(jù),其中1萬行具有較大的梯度。因此,我的算法將選擇(10k行具有較大梯度 + 剩余的490k行的x%隨機(jī)選擇)。假設(shè)x為10%,則選擇的總行數(shù)是59k,基于這些行找到了分裂值。
這里的基本假設(shè)是,具有較小梯度的訓(xùn)練實(shí)例具有較小的訓(xùn)練誤差,并且已經(jīng)訓(xùn)練得很好。為了保持相同的數(shù)據(jù)分布,在計(jì)算信息增益時(shí),GOSS引入了一個(gè)常數(shù)乘數(shù),用于具有較小梯度的數(shù)據(jù)實(shí)例。因此,GOSS在減少數(shù)據(jù)實(shí)例數(shù)量和保持學(xué)習(xí)決策樹的準(zhǔn)確性之間取得了良好的平衡。
LGBM在梯度/誤差較大的葉子上進(jìn)一步生長
2. 每個(gè)模型如何處理分類變量?2.1 CatBoostCatBoost具有靈活性,可以提供分類列的索引,以便可以使用one-hot編碼進(jìn)行編碼,使用one_hot_max_size參數(shù)(對(duì)于具有不同值數(shù)量小于或等于給定參數(shù)值的所有特征使用one-hot編碼)。
如果在cat_features參數(shù)中未傳遞任何內(nèi)容,則CatBoost將將所有列視為數(shù)值變量。
注意:如果一個(gè)包含字符串值的列沒有在cat_features中提供,CatBoost會(huì)拋出錯(cuò)誤。另外,默認(rèn)為int類型的列將默認(rèn)視為數(shù)值型,如果要將其視為分類變量,必須在cat_features中指定。
對(duì)于剩余的分類列,其中唯一類別數(shù)大于one_hot_max_size的列,CatBoost使用一種類似于均值編碼但減少過擬合的高效編碼方法。該過程如下:
- 隨機(jī)以隨機(jī)順序?qū)斎胗^測(cè)集進(jìn)行排列,生成多個(gè)隨機(jī)排列;
- 將標(biāo)簽值從浮點(diǎn)數(shù)或類別轉(zhuǎn)換為整數(shù);
使用以下公式將所有分類特征值轉(zhuǎn)換為數(shù)值:
其中,countInClass表示標(biāo)簽值等于“1”的對(duì)象中當(dāng)前分類特征值的出現(xiàn)次數(shù),prior是分子的初步值,由起始參數(shù)確定,totalCount是具有與當(dāng)前分類特征值匹配的當(dāng)前對(duì)象之前的總對(duì)象數(shù)。
數(shù)學(xué)上,可以用以下方程表示:
2.2 LightGBM與CatBoost類似,LightGBM也可以通過輸入特征名稱來處理分類特征。它不會(huì)轉(zhuǎn)換為獨(dú)熱編碼,而且比獨(dú)熱編碼快得多。LGBM使用一種特殊的算法來找到分類特征的分裂值。
2.3 XGBoost注意:在構(gòu)建LGBM數(shù)據(jù)集之前,您應(yīng)該將分類特征轉(zhuǎn)換為整數(shù)類型。即使通過categorical_feature參數(shù)傳遞了字符串值,它也不接受字符串值。
與CatBoost或LGBM不同,XGBoost本身不能處理分類特征,它只接受類似于隨機(jī)森林的數(shù)值型數(shù)據(jù)。因此,在將分類數(shù)據(jù)提供給XGBoost之前,需要執(zhí)行各種編碼,如標(biāo)簽編碼、均值編碼或獨(dú)熱編碼。
3. 理解參數(shù)所有這些模型都有很多要調(diào)整的參數(shù),但我們只討論其中重要的參數(shù)。下面是這些參數(shù)的列表,根據(jù)它們的功能以及在不同模型中的對(duì)應(yīng)參數(shù)。
4. 在數(shù)據(jù)集上的實(shí)現(xiàn)
我使用了2015年航班延誤的Kaggle數(shù)據(jù)集,因?yàn)樗劝诸愄卣饔职瑪?shù)值特征。由于大約有500萬行數(shù)據(jù),這個(gè)數(shù)據(jù)集對(duì)于評(píng)估每種類型的提升模型在速度和準(zhǔn)確性方面的性能是很好的。我將使用這個(gè)數(shù)據(jù)的10%子集,約50萬行。
以下是用于建模的特征:
- MONTH,DAY,DAY_OF_WEEK:數(shù)據(jù)類型int
- AIRLINE和FLIGHT_NUMBER:數(shù)據(jù)類型int
- ORIGIN_AIRPORT和DESTINATION_AIRPORT:數(shù)據(jù)類型字符串
- DEPARTURE_TIME:數(shù)據(jù)類型float
- ARRIVAL_DELAY:這將是目標(biāo)變量,并轉(zhuǎn)換為表示超過10分鐘延誤的布爾變量
DISTANCE和AIR_TIME:數(shù)據(jù)類型float
import pandas as pd, numpy as np, timefrom sklearn.model_selection import train_test_split
data = pd.read_csv("./data/flights.csv")data = data.sample(frac = 0.1, random_state=10)
data = data[["MONTH","DAY","DAY_OF_WEEK","AIRLINE","FLIGHT_NUMBER","DESTINATION_AIRPORT", "ORIGIN_AIRPORT","AIR_TIME", "DEPARTURE_TIME","DISTANCE","ARRIVAL_DELAY"]]data.dropna(inplace=True)
data["ARRIVAL_DELAY"] = (data["ARRIVAL_DELAY"]>10)*1
cols = ["AIRLINE","FLIGHT_NUMBER","DESTINATION_AIRPORT","ORIGIN_AIRPORT"]for item in cols: data[item] = data[item].astype("category").cat.codes + 1
train, test, y_train, y_test = train_test_split(data.drop(["ARRIVAL_DELAY"], axis=1), data["ARRIVAL_DELAY"],random_state=10, test_size=0.25)4.1 XGBoost
import xgboost as xgbfrom sklearn import metricsfrom sklearn.model_selection import GridSearchCV
def auc(m, train, test): return (metrics.roc_auc_score(y_train,m.predict_proba(train)[:,1]), metrics.roc_auc_score(y_test,m.predict_proba(test)[:,1]))
# Parameter Tuningmodel = xgb.XGBClassifier()param_dist = {"max_depth": [10,30,50], "min_child_weight" : [1,3,6], "n_estimators": [200], "learning_rate": [0.05, 0.1,0.16],}grid_search = GridSearchCV(model, param_grid=param_dist, cv = 3, verbose=10, n_jobs=-1)grid_search.fit(train, y_train)
grid_search.best_estimator_
model = xgb.XGBClassifier(max_depth=50, min_child_weight=1, n_estimators=200,\ n_jobs=-1 , verbose=1,learning_rate=0.16)model.fit(train,y_train)
auc(model, train, test)4.2 LightGBM
4.3 CatBoostimport lightgbm as lgbfrom sklearn import metrics
def auc2(m, train, test): return (metrics.roc_auc_score(y_train,m.predict(train)), metrics.roc_auc_score(y_test,m.predict(test)))
lg = lgb.LGBMClassifier(verbose=0)param_dist = {"max_depth": [25,50, 75], "learning_rate" : [0.01,0.05,0.1], "num_leaves": [300,900,1200], "n_estimators": [200] }grid_search = GridSearchCV(lg, n_jobs=-1, param_grid=param_dist, cv = 3, scoring="roc_auc", verbose=5)grid_search.fit(train,y_train)grid_search.best_estimator_
d_train = lgb.Dataset(train, label=y_train)params = {"max_depth": 50, "learning_rate" : 0.1, "num_leaves": 900, "n_estimators": 300}
# Without Categorical Featuresmodel2 = lgb.train(params, d_train)auc2(model2, train, test)
# With Catgeorical Featurescate_features_name = ["MONTH","DAY","DAY_OF_WEEK","AIRLINE","DESTINATION_AIRPORT", "ORIGIN_AIRPORT"]model2 = lgb.train(params, d_train, categorical_feature = cate_features_name)auc2(model2, train, test)
在調(diào)整CatBoost的參數(shù)時(shí),很難傳遞分類特征的索引。因此,我在沒有傳遞分類特征的情況下調(diào)整了參數(shù),并評(píng)估了兩個(gè)模型——一個(gè)使用分類特征,另一個(gè)不使用分類特征。我單獨(dú)調(diào)整了one_hot_max_size,因?yàn)樗粫?huì)影響其他參數(shù)。
import catboost as cbcat_features_index = [0,1,2,3,4,5,6]5. 結(jié)論
def auc(m, train, test): return (metrics.roc_auc_score(y_train,m.predict_proba(train)[:,1]), metrics.roc_auc_score(y_test,m.predict_proba(test)[:,1]))
params = {'depth': [4, 7, 10], 'learning_rate' : [0.03, 0.1, 0.15], 'l2_leaf_reg': [1,4,9], 'iterations': [300]}cb = cb.CatBoostClassifier()cb_model = GridSearchCV(cb, params, scoring="roc_auc", cv = 3)cb_model.fit(train, y_train)
With Categorical featuresclf = cb.CatBoostClassifier(eval_metric="AUC", depth=10, iterations= 500, l2_leaf_reg= 9, learning_rate= 0.15)clf.fit(train,y_train)auc(clf, train, test)
With Categorical featuresclf = cb.CatBoostClassifier(eval_metric="AUC",one_hot_max_size=31, \ depth=10, iterations= 500, l2_leaf_reg= 9, learning_rate= 0.15)clf.fit(train,y_train, cat_features= cat_features_index)auc(clf, train, test)
在評(píng)估模型時(shí),我們應(yīng)該從速度和準(zhǔn)確性兩個(gè)方面考慮模型的性能。
考慮到這一點(diǎn),CatBoost是贏家,測(cè)試集上的準(zhǔn)確率最高(0.816),過擬合最?。ㄓ?xùn)練集和測(cè)試集的準(zhǔn)確率接近)且預(yù)測(cè)時(shí)間和調(diào)優(yōu)時(shí)間最短。但這僅僅是因?yàn)槲覀兛紤]了分類變量并調(diào)整了one_hot_max_size。如果我們不利用CatBoost的這些特性,它的準(zhǔn)確率只有0.752,表現(xiàn)最差。因此,我們得出結(jié)論,CatBoost僅在數(shù)據(jù)中存在分類變量且我們正確調(diào)整它們時(shí)表現(xiàn)良好。
我們的下一個(gè)表現(xiàn)良好的模型是XGBoost。即使忽略了我們?cè)跀?shù)據(jù)中有分類變量并將其轉(zhuǎn)換為數(shù)值變量供XGBoost使用的事實(shí),它的準(zhǔn)確率仍與CatBoost相當(dāng)接近。然而,XGBoost唯一的問題是速度太慢。調(diào)整其參數(shù)真的很令人沮喪,特別是使用GridSearchCV(運(yùn)行GridSearchCV花費(fèi)了我6個(gè)小時(shí),非常糟糕的主意?。8玫姆椒ㄊ菃为?dú)調(diào)整參數(shù),而不是使用GridSearchCV。閱讀這篇博文,了解如何巧妙地調(diào)整參數(shù)。
最后,LightGBM排名最后。這里需要注意的一點(diǎn)是,當(dāng)使用cat_features時(shí),它在速度和準(zhǔn)確性方面表現(xiàn)不佳。我認(rèn)為它表現(xiàn)糟糕的原因是它對(duì)分類數(shù)據(jù)使用了某種修改過的均值編碼,導(dǎo)致過擬合(訓(xùn)練準(zhǔn)確率非常高——0.999,相比之下測(cè)試準(zhǔn)確率較低)。然而,如果像XGBoost那樣正常使用它,它可以以比XGBoost快得多的速度實(shí)現(xiàn)類似(甚至更高)的準(zhǔn)確性(LGBM——0.785,XGBoost——0.789)。
最后,我必須說這些觀察結(jié)果適用于這個(gè)特定的數(shù)據(jù)集,對(duì)于其他數(shù)據(jù)集可能有效也可能無效。然而,一般來說,一個(gè)真實(shí)的情況是XGBoost比其他兩種算法更慢。
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。