從頭開始:用Python實(shí)現(xiàn)決策樹算法
決策樹算法是一個(gè)強(qiáng)大的預(yù)測方法,它非常流行。因?yàn)樗鼈兊哪P湍軌蜃屝率州p而易舉地理解得和專家一樣好,所以它們比較流行。同時(shí),最終生成的決策樹能夠解釋做出特定預(yù)測的確切原因,這使它們在實(shí)際運(yùn)用中倍受親睞。
同時(shí),決策樹算法也為更高級(jí)的集成模型(如 bagging、隨機(jī)森林及 gradient boosting)提供了基礎(chǔ)。
在這篇教程中,你將會(huì)從零開始,學(xué)習(xí)如何用 Python 實(shí)現(xiàn)《Classification And Regression Tree algorithm》中所說的內(nèi)容。
在學(xué)完該教程之后,你將會(huì)知道:
如何計(jì)算并評價(jià)數(shù)據(jù)集中地候選分割點(diǎn)(Candidate Split Point)
如何在決策樹結(jié)構(gòu)中排分配這些分割點(diǎn)
如何在實(shí)際問題中應(yīng)用這些分類和回歸算法
一、概要
本節(jié)簡要介紹了關(guān)于分類及回歸樹(Classification and Regression Trees)算法的一些內(nèi)容,并給出了將在本教程中使用的鈔票數(shù)據(jù)集(Banknote Dataset)。
1.1 分類及回歸樹
分類及回歸樹(CART)是由 Leo Breiman 提出的一個(gè)術(shù)語,用來描述一種能被用于分類或者回歸預(yù)測模型問題的回歸樹算法。
我們將在本教程中主要討論 CART 在分類問題上的應(yīng)用。
二叉樹(Binary Tree)是 CART 模型的代表之一。這里所說的二叉樹,與數(shù)據(jù)結(jié)構(gòu)和算法里面所說的二叉樹別無二致,沒有什么特別之處(每個(gè)節(jié)點(diǎn)可以有 0、1 或 2 個(gè)子節(jié)點(diǎn))。
每個(gè)節(jié)點(diǎn)代表在節(jié)點(diǎn)處有一個(gè)輸入變量被傳入,并根據(jù)某些變量被分類(我們假定該變量是數(shù)值型的)。樹的葉節(jié)點(diǎn)(又叫做終端節(jié)點(diǎn),Terminal Node)由輸出變量構(gòu)成,它被用于進(jìn)行預(yù)測。
在樹被創(chuàng)建完成之后,每個(gè)新的數(shù)據(jù)樣本都將按照每個(gè)節(jié)點(diǎn)的分割條件,沿著該樹從頂部往下,直到輸出一個(gè)最終決策。
創(chuàng)建一個(gè)二元分類樹實(shí)際上是一個(gè)分割輸入空間的過程。遞歸二元分類(Recursive Binary Splitting)是一個(gè)被用于分割空間的貪心算法。這實(shí)際上是一個(gè)數(shù)值過程:當(dāng)一系列的輸入值被排列好后,它將嘗試一系列的分割點(diǎn),測試它們分類完后成本函數(shù)(Cost Function)的值。
有最優(yōu)成本函數(shù)(通常是最小的成本函數(shù),因?yàn)槲覀兺M撝底钚?的分割點(diǎn)將會(huì)被選擇。根據(jù)貪心法(greedy approach)原則,所有的輸入變量和所有可能的分割點(diǎn)都將被測試,并會(huì)基于它們成本函數(shù)的表現(xiàn)被評估。(譯者注:下面簡述對回歸問題和分類問題常用的成本函數(shù)。)
- 回歸問題:對落在分割點(diǎn)確定區(qū)域內(nèi)所有的樣本取誤差平方和(Sum Squared Error)。
- 分類問題:一般采用基尼成本函數(shù)(Gini Cost Function),它能夠表明被分割之后每個(gè)節(jié)點(diǎn)的純凈度(Node Purity)如何。其中,節(jié)點(diǎn)純凈度是一種表明每個(gè)節(jié)點(diǎn)分類后訓(xùn)練數(shù)據(jù)混雜程度的指標(biāo)。
分割將一直進(jìn)行,直到每個(gè)節(jié)點(diǎn)(分類后)都只含有最小數(shù)量的訓(xùn)練樣本或者樹的深度達(dá)到了最大值。
1.2 Banknote 數(shù)據(jù)集
Banknote 數(shù)據(jù)集,需要我們根據(jù)對紙幣照片某些性質(zhì)的分析,來預(yù)測該鈔票的真?zhèn)巍?/p>
該數(shù)據(jù)集中含有 1372 個(gè)樣本,每個(gè)樣本由 5 個(gè)數(shù)值型變量構(gòu)成。這是一個(gè)二元分類問題。如下列舉 5 個(gè)變量的含義及數(shù)據(jù)性質(zhì):
1. 圖像經(jīng)小波變換后的方差(Variance)(連續(xù)值)
2. 圖像經(jīng)小波變換后的偏度(Skewness)(連續(xù)值)
3. 圖像經(jīng)小波變換后的峰度(Kurtosis)(連續(xù)值)
4. 圖像的熵(Entropy)(連續(xù)值)
5. 鈔票所屬類別(整數(shù),離散值)
如下是數(shù)據(jù)集前五行數(shù)據(jù)的樣本。
- 3.6216,8.6661,-2.8073,-0.44699,0
- 4.5459,8.1674,-2.4586,-1.4621,0
- 3.866,-2.6383,1.9242,0.10645,0
- 3.4566,9.5228,-4.0112,-3.5944,0
- 0.32924,-4.4552,4.5718,-0.9888,0
- 4.3684,9.6718,-3.9606,-3.1625,0
使用零規(guī)則算法(Zero Rule Algorithm)來預(yù)測最常出現(xiàn)類別的情況(譯者注:也就是找到最常出現(xiàn)的一類樣本,然后預(yù)測所有的樣本都是這個(gè)類別),對該問的基準(zhǔn)準(zhǔn)確大概是 50%。
你可以在這里下載并了解更多關(guān)于這個(gè)數(shù)據(jù)集的內(nèi)容:UCI Machine Learning Repository。
請下載該數(shù)據(jù)集,放到你當(dāng)前的工作目錄,并重命名該文件為 data_banknote_authentication.csv。
二、教程
本教程分為五大部分:
1. 對基尼系數(shù)(Gini Index)的介紹
2.(如何)創(chuàng)建分割點(diǎn)
3.(如何)生成樹模型
4.(如何)利用模型進(jìn)行預(yù)測
5. 對鈔票數(shù)據(jù)集的案例研究
這些步驟能幫你打好基礎(chǔ),讓你能夠從零實(shí)現(xiàn) CART 算法,并能將它應(yīng)用到你子集的預(yù)測模型問題中。
2.1 基尼系數(shù)
基尼系數(shù)是一種評估數(shù)據(jù)集分割點(diǎn)優(yōu)劣的成本函數(shù)。
數(shù)據(jù)集的分割點(diǎn)是關(guān)于輸入中某個(gè)屬性的分割。對數(shù)據(jù)集中某個(gè)樣本而言,分割點(diǎn)會(huì)根據(jù)某閾值對該樣本對應(yīng)屬性的值進(jìn)行分類。他能根據(jù)訓(xùn)練集中出現(xiàn)的模式將數(shù)據(jù)分為兩類。
基尼系數(shù)通過計(jì)算分割點(diǎn)創(chuàng)建的兩個(gè)類別中數(shù)據(jù)類別的混雜程度,來表現(xiàn)分割點(diǎn)的好壞。一個(gè)完美的分割點(diǎn)對應(yīng)的基尼系數(shù)為 0(譯者注:即在一類中不會(huì)出現(xiàn)另一類的數(shù)據(jù),每個(gè)類都是「純」的),而最差的分割點(diǎn)的基尼系數(shù)則為 1.0(對于二分問題,每一類中出現(xiàn)另一類數(shù)據(jù)的比例都為 50%,也就是數(shù)據(jù)完全沒能被根據(jù)類別不同區(qū)分開)。
下面我們通過一個(gè)具體的例子來說明如何計(jì)算基尼系數(shù)。
我們有兩組數(shù)據(jù),每組有兩行。第一組數(shù)據(jù)中所有行都屬于類別 0(Class 0),第二組數(shù)據(jù)中所有的行都屬于類別 1(Class 1)。這是一個(gè)完美的分割點(diǎn)。
首先我們要按照下式計(jì)算每組數(shù)據(jù)中各類別數(shù)據(jù)的比例:
- proportion = count(class_value) / count(rows)
那么,對本例而言,相應(yīng)的比例為:
- group_1_class_0 = 2 / 2 = 1
- group_1_class_1 = 0 / 2 = 0
- group_2_class_0 = 0 / 2 = 0
- group_2_class_1 = 2 / 2 = 1
基尼系數(shù)按照如下公式計(jì)算:
- gini_index = sum(proportion * (1.0 - proportion))
將本例中所有組、所有類數(shù)據(jù)的比例帶入到上述公式:
- gini_index = (group_1_class_0 * (1.0 - group_1_class_0)) +
- (group_1_class_1 * (1.0 - group_1_class_1)) +
- (group_2_class_0 * (1.0 - group_2_class_0)) +
- (group_2_class_1 * (1.0 - group_2_class_1))
化簡,得:
- gini_index = 0 + 0 + 0 + 0 = 0
如下是一個(gè)叫做 gini_index() 的函數(shù),它能夠計(jì)算給定數(shù)據(jù)的基尼系數(shù)(組、類別都以列表(list)的形式給出)。其中有些算法魯棒性檢測,能夠避免對空組除以 0 的情況。
- # Calculate the Gini index for a split dataset
- def gini_index(groups, class_values):
- gini = 0.0
- for class_value in class_values:
- for group in groups:
- size = len(group)
- if size == 0:
- continue
- proportion = [row[-1] for row in group].count(class_value) / float(size)
- gini += (proportion * (1.0 - proportion))
- return gini
我們可以根據(jù)上例來測試該函數(shù)的運(yùn)行情況,也可以測試最差分割點(diǎn)的情況。完整的代碼如下:
- # Calculate the Gini index for a split dataset
- def gini_index(groups, class_values):
- gini = 0.0
- for class_value in class_values:
- for group in groups:
- size = len(group)
- if size == 0:
- continue
- proportion = [row[-1] for row in group].count(class_value) / float(size)
- gini += (proportion * (1.0 - proportion))
- return gini
- # test Gini values
- print(gini_index([[[1, 1], [1, 0]], [[1, 1], [1, 0]]], [0, 1]))
- print(gini_index([[[1, 0], [1, 0]], [[1, 1], [1, 1]]], [0, 1]))
運(yùn)行該代碼,將會(huì)打印兩個(gè)基尼系數(shù),其中第一個(gè)對應(yīng)的是最差的情況為 1.0,第二個(gè)對應(yīng)的是最好的情況為 0.0。
- 1.0
- 0.0
2.2 創(chuàng)建分割點(diǎn)
一個(gè)分割點(diǎn)由數(shù)據(jù)集中的一個(gè)屬性和一個(gè)閾值構(gòu)成。
我們可以將其總結(jié)為對給定的屬性確定一個(gè)分割數(shù)據(jù)的閾值。這是一種行之有效的分類數(shù)據(jù)的方法。
創(chuàng)建分割點(diǎn)包括三個(gè)步驟,其中第一步已在計(jì)算基尼系數(shù)的部分討論過。余下兩部分分別為:
1. 分割數(shù)據(jù)集。
2. 評價(jià)所有(可行的)分割點(diǎn)。
我們具體看一下每個(gè)步驟。
2.2.1 分割數(shù)據(jù)集
分割數(shù)據(jù)集意味著我們給定數(shù)據(jù)集某屬性(或其位于屬性列表中的下表)及相應(yīng)閾值的情況下,將數(shù)據(jù)集分為兩個(gè)部分。
一旦數(shù)據(jù)被分為兩部分,我們就可以使用基尼系數(shù)來評估該分割的成本函數(shù)。
分割數(shù)據(jù)集需要對每行數(shù)據(jù)進(jìn)行迭代,根據(jù)每個(gè)數(shù)據(jù)點(diǎn)相應(yīng)屬性的值與閾值的大小情況將該數(shù)據(jù)點(diǎn)放到相應(yīng)的部分(對應(yīng)樹結(jié)構(gòu)中的左叉與右叉)。
如下是一個(gè)名為 test_split() 的函數(shù),它能實(shí)現(xiàn)上述功能:
- # Split a dataset based on an attribute and an attribute value
- def test_split(index, value, dataset):
- left, right = list(), list()
- for row in dataset:
- if row[index] < value:
- left.append(row)
- else:
- right.append(row)
- return left, right
代碼還是很簡單的。
注意,在代碼中,屬性值大于或等于閾值的數(shù)據(jù)點(diǎn)被分類到了右組中。
2.2.2 評價(jià)所有分割點(diǎn)
在基尼函數(shù) gini_index() 和分類函數(shù) test_split() 的幫助下,我們可以開始進(jìn)行評估分割點(diǎn)的流程。
對給定的數(shù)據(jù)集,對每一個(gè)屬性,我們都要檢查所有的可能的閾值使之作為候選分割點(diǎn)。然后,我們將根據(jù)這些分割點(diǎn)的成本(cost)對其進(jìn)行評估,最終挑選出最優(yōu)的分割點(diǎn)。
當(dāng)最優(yōu)分割點(diǎn)被找到之后,我們就能用它作為我們決策樹中的一個(gè)節(jié)點(diǎn)。
而這也就是所謂的窮舉型貪心算法。
在該例中,我們將使用一個(gè)詞典來代表決策樹中的一個(gè)節(jié)點(diǎn),它能夠按照變量名儲(chǔ)存數(shù)據(jù)。當(dāng)選擇了最優(yōu)分割點(diǎn)并使用它作為樹的新節(jié)點(diǎn)時(shí),我們存下對應(yīng)屬性的下標(biāo)、對應(yīng)分割值及根據(jù)分割值分割后的兩部分?jǐn)?shù)據(jù)。
分割后地每一組數(shù)據(jù)都是一個(gè)更小規(guī)模地?cái)?shù)據(jù)集(可以繼續(xù)進(jìn)行分割操作),它實(shí)際上就是原始數(shù)據(jù)集中地?cái)?shù)據(jù)按照分割點(diǎn)被分到了左叉或右叉的數(shù)據(jù)集。你可以想象我們可以進(jìn)一步將每一組數(shù)據(jù)再分割,不斷循環(huán)直到建構(gòu)出整個(gè)決策樹。
如下是一個(gè)名為 get_split() 的函數(shù),它能實(shí)現(xiàn)上述的步驟。你會(huì)發(fā)現(xiàn),它遍歷了每一個(gè)屬性(除了類別值)以及屬性對應(yīng)的每一個(gè)值,在每次迭代中它都會(huì)分割數(shù)據(jù)并評估該分割點(diǎn)。
當(dāng)所有的檢查完成后,最優(yōu)的分割點(diǎn)將被記錄并返回。
- # Select the best split point for a dataset
- def get_split(dataset):
- class_values = list(set(row[-1] for row in dataset))
- b_index, b_value, b_score, b_groups = 999, 999, 999, None
- for index in range(len(dataset[0])-1):
- for row in dataset:
- groups = test_split(index, row[index], dataset)
- gini = gini_index(groups, class_values)
- if gini < b_score:
- b_index, b_value, b_score, b_groups = index, row[index], gini, groups
- return {'index':b_index, 'value':b_value, 'groups':b_groups}
我們能在一個(gè)小型合成的數(shù)據(jù)集上來測試這個(gè)函數(shù)以及整個(gè)數(shù)據(jù)集分割的過程。
- X1 X2 Y
- 2.771244718 1.784783929 0
- 1.728571309 1.169761413 0
- 3.678319846 2.81281357 0
- 3.961043357 2.61995032 0
- 2.999208922 2.209014212 0
- 7.497545867 3.162953546 1
- 9.00220326 3.339047188 1
- 7.444542326 0.476683375 1
- 10.12493903 3.234550982 1
- 6.642287351 3.319983761 1
同時(shí),我們可以使用不同顏色標(biāo)記不同的類,將該數(shù)據(jù)集繪制出來。由圖可知,我們可以從 X1 軸(即圖中的 X 軸)上挑出一個(gè)值來分割該數(shù)據(jù)集。
范例所有的代碼整合如下:
- # Split a dataset based on an attribute and an attribute value
- def test_split(index, value, dataset):
- left, right = list(), list()
- for row in dataset:
- if row[index] < value:
- left.append(row)
- else:
- right.append(row)
- return left, right
- # Calculate the Gini index for a split dataset
- def gini_index(groups, class_values):
- gini = 0.0
- for class_value in class_values:
- for group in groups:
- size = len(group)
- if size == 0:
- continue
- proportion = [row[-1] for row in group].count(class_value) / float(size)
- gini += (proportion * (1.0 - proportion))
- return gini
- # Select the best split point for a dataset
- def get_split(dataset):
- class_values = list(set(row[-1] for row in dataset))
- b_index, b_value, b_score, b_groups = 999, 999, 999, None
- for index in range(len(dataset[0])-1):
- for row in dataset:
- groups = test_split(index, row[index], dataset)
- gini = gini_index(groups, class_values)
- print('X%d < %.3f Gini=%.3f' % ((index+1), row[index], gini))
- if gini < b_score:
- b_index, b_value, b_score, b_groups = index, row[index], gini, groups
- return {'index':b_index, 'value':b_value, 'groups':b_groups}
- dataset = [[2.771244718,1.784783929,0],
- [1.728571309,1.169761413,0],
- [3.678319846,2.81281357,0],
- [3.961043357,2.61995032,0],
- [2.999208922,2.209014212,0],
- [7.497545867,3.162953546,1],
- [9.00220326,3.339047188,1],
- [7.444542326,0.476683375,1],
- [10.12493903,3.234550982,1],
- [6.642287351,3.319983761,1]]
- split = get_split(dataset)
- print('Split: [X%d < %.3f]' % ((split['index']+1), split['value']))
優(yōu)化后的 get_split() 函數(shù)能夠輸出每個(gè)分割點(diǎn)及其對應(yīng)的基尼系數(shù)。
運(yùn)行如上的代碼后,它將 print 所有的基尼系數(shù)及其選中的最優(yōu)分割點(diǎn)。在此范例中,它選中了 X1<6.642 作為最終完美分割點(diǎn)(它對應(yīng)的基尼系數(shù)為 0)。
- X1 < 2.771 Gini=0.494
- X1 < 1.729 Gini=0.500
- X1 < 3.678 Gini=0.408
- X1 < 3.961 Gini=0.278
- X1 < 2.999 Gini=0.469
- X1 < 7.498 Gini=0.408
- X1 < 9.002 Gini=0.469
- X1 < 7.445 Gini=0.278
- X1 < 10.125 Gini=0.494
- X1 < 6.642 Gini=0.000
- X2 < 1.785 Gini=1.000
- X2 < 1.170 Gini=0.494
- X2 < 2.813 Gini=0.640
- X2 < 2.620 Gini=0.819
- X2 < 2.209 Gini=0.934
- X2 < 3.163 Gini=0.278
- X2 < 3.339 Gini=0.494
- X2 < 0.477 Gini=0.500
- X2 < 3.235 Gini=0.408
- X2 < 3.320 Gini=0.469
- Split: [X1 < 6.642]
既然我們現(xiàn)在已經(jīng)能夠找出數(shù)據(jù)集中最優(yōu)的分割點(diǎn),那我們現(xiàn)在就來看看我們能如何應(yīng)用它來建立一個(gè)決策樹。
2.3 生成樹模型
創(chuàng)建樹的根節(jié)點(diǎn)(root node)是比較方便的,可以調(diào)用 get_split() 函數(shù)并傳入整個(gè)數(shù)據(jù)集即可達(dá)到此目的。但向樹中增加更多的節(jié)點(diǎn)則比較有趣。
建立樹結(jié)構(gòu)主要分為三個(gè)步驟:
1. 創(chuàng)建終端節(jié)點(diǎn)
2. 遞歸地分割
3. 建構(gòu)整棵樹
2.3.1 創(chuàng)建終端節(jié)點(diǎn)
我們需要決定何時(shí)停止樹的「增長」。
我們可以用兩個(gè)條件進(jìn)行控制:樹的深度和每個(gè)節(jié)點(diǎn)分割后的數(shù)據(jù)點(diǎn)個(gè)數(shù)。
最大樹深度:這代表了樹中從根結(jié)點(diǎn)算起節(jié)點(diǎn)數(shù)目的上限。一旦樹中的節(jié)點(diǎn)樹達(dá)到了這一上界,則算法將會(huì)停止分割數(shù)據(jù)、增加新的節(jié)點(diǎn)。更神的樹會(huì)更為復(fù)雜,也更有可能過擬合訓(xùn)練集。
最小節(jié)點(diǎn)記錄數(shù):這是某節(jié)點(diǎn)分割數(shù)據(jù)后分個(gè)部分?jǐn)?shù)據(jù)個(gè)數(shù)的最小值。一旦達(dá)到或低于該最小值,則算法將會(huì)停止分割數(shù)據(jù)、增加新的節(jié)點(diǎn)。將數(shù)據(jù)集分為只有很少數(shù)據(jù)點(diǎn)的兩個(gè)部分的分割節(jié)點(diǎn)被認(rèn)為太具針對性,并很有可能過擬合訓(xùn)練集。
這兩個(gè)方法基于用戶給定的參數(shù),參與到樹模型的構(gòu)建過程中。
此外,還有一個(gè)情況。算法有可能選擇一個(gè)分割點(diǎn),分割數(shù)據(jù)后所有的數(shù)據(jù)都被分割到同一組內(nèi)(也就是左叉、右叉只有一個(gè)分支上有數(shù)據(jù),另一個(gè)分支沒有)。在這樣的情況下,因?yàn)樵跇涞牧硪粋€(gè)分叉沒有數(shù)據(jù),我們不能繼續(xù)我們的分割與添加節(jié)點(diǎn)的工作。
基于上述內(nèi)容,我們已經(jīng)有一些停止樹「增長」的判別機(jī)制。當(dāng)樹在某一結(jié)點(diǎn)停止增長的時(shí)候,該節(jié)點(diǎn)被稱為終端節(jié)點(diǎn),并被用來進(jìn)行最終預(yù)測。
預(yù)測的過程是通過選擇組表征值進(jìn)行的。當(dāng)遍歷樹進(jìn)入到最終節(jié)點(diǎn)分割后的數(shù)據(jù)組中,算法將會(huì)選擇該組中最普遍出現(xiàn)的值作為預(yù)測值。
如下是一個(gè)名為 to_terminal() 的函數(shù),對每一組收據(jù)它都能選擇一個(gè)表征值。他能夠返回一系列數(shù)據(jù)點(diǎn)中最普遍出現(xiàn)的值。
- # Create a terminal node value
- def to_terminal(group):
- outcomes = [row[-1] for row in group]
- return max(set(outcomes), key=outcomes.count)
2.3.2 遞歸分割
在了解了如何及何時(shí)創(chuàng)建終端節(jié)點(diǎn)后,我們現(xiàn)在可以開始建立樹模型了。
建立樹地模型,需要我們對給定的數(shù)據(jù)集反復(fù)調(diào)用如上定義的 get_split() 函數(shù),不斷創(chuàng)建樹中的節(jié)點(diǎn)。
在已有節(jié)點(diǎn)下加入的新節(jié)點(diǎn)叫做子節(jié)點(diǎn)。對樹中的任意節(jié)點(diǎn)而言,它可能沒有子節(jié)點(diǎn)(則該節(jié)點(diǎn)為終端節(jié)點(diǎn))、一個(gè)子節(jié)點(diǎn)(則該節(jié)點(diǎn)能夠直接進(jìn)行預(yù)測)或兩個(gè)子節(jié)點(diǎn)。在程序中,在表示某節(jié)點(diǎn)的字典中,我們將一棵樹的兩子節(jié)點(diǎn)命名為 left 和 right。
一旦一個(gè)節(jié)點(diǎn)被創(chuàng)建,我們就可以遞歸地對在該節(jié)點(diǎn)被分割得到的兩個(gè)子數(shù)據(jù)集上調(diào)用相同的函數(shù),來分割子數(shù)據(jù)集并創(chuàng)建新的節(jié)點(diǎn)。
如下是一個(gè)實(shí)現(xiàn)該遞歸過程的函數(shù)。它的輸入?yún)?shù)包括:某一節(jié)點(diǎn)(node)、最大樹深度(max_depth)、最小節(jié)點(diǎn)記錄數(shù)(min_size)及當(dāng)前樹深度(depth)。
顯然,一開始運(yùn)行該函數(shù)時(shí),根節(jié)點(diǎn)將被傳入,當(dāng)前深度為 1。函數(shù)的功能分為如下幾步:
1. 首先,該節(jié)點(diǎn)分割的兩部分?jǐn)?shù)據(jù)將被提取出來以便使用,同時(shí)數(shù)據(jù)將被在節(jié)點(diǎn)中刪除(隨著分割工作的逐步進(jìn)行,之前的節(jié)點(diǎn)不需要再使用相應(yīng)的數(shù)據(jù))。
2. 然后,我們將會(huì)檢查該節(jié)點(diǎn)的左叉及右叉的數(shù)據(jù)集是否為空。如果是,則其將會(huì)創(chuàng)建一個(gè)終端節(jié)點(diǎn)。
3. 同時(shí),我們會(huì)檢查是否到達(dá)了最大深度。如果是,則其將會(huì)創(chuàng)建一個(gè)終端節(jié)點(diǎn)。
4. 接著,我們將對左子節(jié)點(diǎn)進(jìn)一步操作。若該組數(shù)據(jù)個(gè)數(shù)小于閾值,則會(huì)創(chuàng)建一個(gè)終端節(jié)點(diǎn)并停止進(jìn)一步操作。否則它將會(huì)以一種深度優(yōu)先的方式創(chuàng)建并添加節(jié)點(diǎn),直到該分叉達(dá)到底部。
5. 對右子節(jié)點(diǎn)同樣進(jìn)行上述操作,不斷增加節(jié)點(diǎn)直到達(dá)到終端節(jié)點(diǎn)。
2.3.3 建構(gòu)整棵樹
我們將所有的內(nèi)容整合到一起。
創(chuàng)建一棵樹包括創(chuàng)建根節(jié)點(diǎn)及遞歸地調(diào)用 split() 函數(shù)來不斷地分割數(shù)據(jù)以構(gòu)建整棵樹。
如下是實(shí)現(xiàn)上述功能的 bulid_tree() 函數(shù)的簡化版本。
- # Build a decision tree
- def build_tree(train, max_depth, min_size):
- root = get_split(dataset)
- split(root, max_depth, min_size, 1)
- return root
我們可以在如上所述的合成數(shù)據(jù)集上測試整個(gè)過程。如下是完整的案例。
在其中還包括了一個(gè) print_tree() 函數(shù),它能夠遞歸地一行一個(gè)地打印出決策樹的節(jié)點(diǎn)。經(jīng)過它打印的不是一個(gè)明顯的樹結(jié)構(gòu),但它能給我們關(guān)于樹結(jié)構(gòu)的大致印象,并能幫助決策。
- # Split a dataset based on an attribute and an attribute value
- def test_split(index, value, dataset):
- left, right = list(), list()
- for row in dataset:
- if row[index] < value:
- left.append(row)
- else:
- right.append(row)
- return left, right
- # Calculate the Gini index for a split dataset
- def gini_index(groups, class_values):
- gini = 0.0
- for class_value in class_values:
- for group in groups:
- size = len(group)
- if size == 0:
- continue
- proportion = [row[-1] for row in group].count(class_value) / float(size)
- gini += (proportion * (1.0 - proportion))
- return gini
- # Select the best split point for a dataset
- def get_split(dataset):
- class_values = list(set(row[-1] for row in dataset))
- b_index, b_value, b_score, b_groups = 999, 999, 999, None
- for index in range(len(dataset[0])-1):
- for row in dataset:
- groups = test_split(index, row[index], dataset)
- gini = gini_index(groups, class_values)
- if gini < b_score:
- b_index, b_value, b_score, b_groups = index, row[index], gini, groups
- return {'index':b_index, 'value':b_value, 'groups':b_groups}
- # Create a terminal node value
- def to_terminal(group):
- outcomes = [row[-1] for row in group]
- return max(set(outcomes), key=outcomes.count)
- # Create child splits for a node or make terminal
- def split(node, max_depth, min_size, depth):
- left, right = node['groups']
- del(node['groups'])
- # check for a no split
- if not left or not right:
- node['left'] = node['right'] = to_terminal(left + right)
- return
- # check for max depth
- if depth >= max_depth:
- node['left'], node['right'] = to_terminal(left), to_terminal(right)
- return
- # process left child
- if len(left) <= min_size:
- node['left'] = to_terminal(left)
- else:
- node['left'] = get_split(left)
- split(node['left'], max_depth, min_size, depth+1)
- # process right child
- if len(right) <= min_size:
- node['right'] = to_terminal(right)
- else:
- node['right'] = get_split(right)
- split(node['right'], max_depth, min_size, depth+1)
- # Build a decision tree
- def build_tree(train, max_depth, min_size):
- root = get_split(dataset)
- split(root, max_depth, min_size, 1)
- return root
- # Print a decision tree
- def print_tree(node, depth=0):
- if isinstance(node, dict):
- print('%s[X%d < %.3f]' % ((depth*' ', (node['index']+1), node['value'])))
- print_tree(node['left'], depth+1)
- print_tree(node['right'], depth+1)
- else:
- print('%s[%s]' % ((depth*' ', node)))
- dataset = [[2.771244718,1.784783929,0],
- [1.728571309,1.169761413,0],
- [3.678319846,2.81281357,0],
- [3.961043357,2.61995032,0],
- [2.999208922,2.209014212,0],
- [7.497545867,3.162953546,1],
- [9.00220326,3.339047188,1],
- [7.444542326,0.476683375,1],
- [10.12493903,3.234550982,1],
- [6.642287351,3.319983761,1]]
- tree = build_tree(dataset, 1, 1)
- print_tree(tree)
在運(yùn)行過程中,我們能修改樹的最大深度,并在打印的樹上觀察其影響。
當(dāng)最大深度為 1 時(shí)(即調(diào)用 build_tree() 函數(shù)時(shí)第二個(gè)參數(shù)),我們可以發(fā)現(xiàn)該樹使用了我們之前發(fā)現(xiàn)的完美分割點(diǎn)(作為樹的唯一分割點(diǎn))。該樹只有一個(gè)節(jié)點(diǎn),也被稱為決策樹樁。
- [X1 < 6.642]
- [0]
- [1]
當(dāng)最大深度加到 2 時(shí),我們迫使輸算法不需要分割的情況下強(qiáng)行分割。結(jié)果是,X1 屬性在左右叉上被使用了兩次來分割這個(gè)本已經(jīng)完美分割的數(shù)據(jù)。
- [X1 < 6.642]
- [X1 < 2.771]
- [0]
- [0]
- [X1 < 7.498]
- [1]
- [1]
最后,我們可以試試最大深度為 3 的情況:
- [X1 < 6.642]
- [X1 < 2.771]
- [0]
- [X1 < 2.771]
- [0]
- [0]
- [X1 < 7.498]
- [X1 < 7.445]
- [1]
- [1]
- [X1 < 7.498]
- [1]
- [1]
這些測試表明,我們可以優(yōu)化代碼來避免不必要的分割。請參見延伸章節(jié)的相關(guān)內(nèi)容。
現(xiàn)在我們已經(jīng)可以(完整地)創(chuàng)建一棵決策樹了,那么我們來看看如何用它來在新數(shù)據(jù)上做出預(yù)測吧。
2.4 利用模型進(jìn)行預(yù)測
使用決策樹模型進(jìn)行決策,需要我們根據(jù)給出的數(shù)據(jù)遍歷整棵決策樹。
與前面相同,我們?nèi)孕枰褂靡粋€(gè)遞歸函數(shù)來實(shí)現(xiàn)該過程。其中,基于某分割點(diǎn)對給出數(shù)據(jù)的影響,相同的預(yù)測規(guī)則被應(yīng)用到左子節(jié)點(diǎn)或右子節(jié)點(diǎn)上。
我們需要檢查對某子節(jié)點(diǎn)而言,它是否是一個(gè)可以被作為預(yù)測結(jié)果返回的終端節(jié)點(diǎn),又或是他是否含有下一層的分割節(jié)點(diǎn)需要被考慮。
如下是實(shí)現(xiàn)上述過程的名為 predict() 函數(shù),你可以看到它是如何處理給定節(jié)點(diǎn)的下標(biāo)與數(shù)值的。
接著,我們使用合成的數(shù)據(jù)集來測試該函數(shù)。如下是一個(gè)使用僅有一個(gè)節(jié)點(diǎn)的硬編碼樹(即決策樹樁)的案例。該案例中對數(shù)據(jù)集中的每個(gè)數(shù)據(jù)進(jìn)行了預(yù)測。
運(yùn)行該例子,它將按照預(yù)期打印出每個(gè)數(shù)據(jù)的預(yù)測結(jié)果。
- Expected=0, Got=0
- Expected=0, Got=0
- Expected=0, Got=0
- Expected=0, Got=0
- Expected=0, Got=0
- Expected=1, Got=1
- Expected=1, Got=1
- Expected=1, Got=1
- Expected=1, Got=1
- Expected=1, Got=1
現(xiàn)在,我們不僅掌握了如何創(chuàng)建一棵決策樹,同時(shí)還知道如何用它進(jìn)行預(yù)測。那么,我們就來試試在實(shí)際數(shù)據(jù)集上來應(yīng)用該算法吧。
2.5 對鈔票數(shù)據(jù)集的案例研究
該節(jié)描述了在鈔票數(shù)據(jù)集上使用了 CART 算法的流程。
第一步是導(dǎo)入數(shù)據(jù),并轉(zhuǎn)換載入的數(shù)據(jù)到數(shù)值形式,使得我們能夠用它來計(jì)算分割點(diǎn)。對此,我們使用了輔助函數(shù) load_csv() 載入數(shù)據(jù)及 str_column_to_float() 以轉(zhuǎn)換字符串?dāng)?shù)據(jù)到浮點(diǎn)數(shù)。
我們將會(huì)使用 5 折交叉驗(yàn)證法(5-fold cross validation)來評估該算法的表現(xiàn)。這也就意味著,對一個(gè)記錄,將會(huì)有 1273/5=274.4 即 270 個(gè)數(shù)據(jù)點(diǎn)。我們將會(huì)使用輔助函數(shù) evaluate_algorithm() 來評估算法在交叉驗(yàn)證集上的表現(xiàn),用 accuracy_metric() 來計(jì)算預(yù)測的準(zhǔn)確率。
完成的代碼如下:
上述使用的參數(shù)包括:max_depth 為 5,min_size 為 10。經(jīng)過了一些實(shí)現(xiàn)后,我們確定了上述 CART 算法的使用的參數(shù),但這不代表所使用的參數(shù)就是最優(yōu)的。
運(yùn)行該案例,它將會(huì) print 出對每一部分?jǐn)?shù)據(jù)的平均分類準(zhǔn)確度及對所有部分?jǐn)?shù)據(jù)的平均表現(xiàn)。
從數(shù)據(jù)中你可以發(fā)現(xiàn),CART 算法選擇的分類設(shè)置,達(dá)到了大約 83% 的平均分類準(zhǔn)確率。其表現(xiàn)遠(yuǎn)遠(yuǎn)好于只有約 50% 正確率的零規(guī)則算法(Zero Rule algorithm)。
Scores: [83.57664233576642, 84.30656934306569, 85.76642335766424, 81.38686131386861, 81.75182481751825]
Mean Accuracy: 83.358%
三、延伸
本節(jié)列出了關(guān)于該節(jié)的延伸項(xiàng)目,你可以根據(jù)此進(jìn)行探索。
1. 算法調(diào)參(Algorithm Tuning):在鈔票數(shù)據(jù)集上使用的 CART 算法未被調(diào)參。你可以嘗試不同的參數(shù)數(shù)值以獲取更好的更優(yōu)的結(jié)果。
2. 交叉熵(Cross Entropy):另一個(gè)用來評估分割點(diǎn)的成本函數(shù)是交叉熵函數(shù)(對數(shù)損失)。你能夠嘗試使用該成本函數(shù)作為替代。
3. 剪枝(Tree Pruning):另一個(gè)減少在訓(xùn)練過程中過擬合程度的重要方法是剪枝。你可以研究并嘗試實(shí)現(xiàn)一些剪枝的方法。
4. 分類數(shù)據(jù)集(Categorical Dataset):在上述例子中,其樹模型被設(shè)計(jì)用于解決數(shù)值型或有序數(shù)據(jù)。你可以嘗試修改樹模型(主要修改分割的屬性,用等式而非排序的形式),使之能夠應(yīng)對分類型的數(shù)據(jù)。
5. 回歸問題(Regression):可以通過使用不同的成本函數(shù)及不同的創(chuàng)建終端節(jié)點(diǎn)的方法,來讓該模型能夠解決一個(gè)回歸問題。
6. 更多數(shù)據(jù)集:你可以嘗試將該算法用于 UCI Machine Learning Repository 上其他的數(shù)據(jù)集。
【本文是51CTO專欄機(jī)構(gòu)機(jī)器之心的原創(chuàng)文章,微信公眾號(hào)“機(jī)器之心( id: almosthuman2014)”】







































