关于人工智能:使用Python从头开始手写回归树

32次阅读

共计 7109 个字符,预计需要花费 18 分钟才能阅读完成。

在本篇文章中,咱们将介绍回归树及其根本数学原理,并从头开始应用 Python 实现一个残缺的回归树模型。

为了简略起见这里将应用递归来创立树节点,尽管递归不是一个完满的实现,然而对于解释原理他是最直观的。

首先导入库

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

首先须要创立训练数据,咱们的数据将具备独立变量(x)和一个相干的变量(y),并应用 numpy 在相干值中增加高斯噪声,能够用数学表白为

这里的𝜖 是噪声。代码如下所示。

def f(x):
    mu, sigma = 0, 1.5
    return -x**2 + x + 5 + np.random.normal(mu, sigma, 1)

num_points = 300
np.random.seed(1)
    
x = np.random.uniform(-2, 5, num_points)
y = np.array([f(i) for i in x] )

plt.scatter(x, y, s = 5)

回归树

在回归树中是通过创立一个多个节点的树来预测数值数据的。下图展现了一个回归树的树结构示例,其中每个节点都有其用于划分数据的阈值。

给定一组数据,输出值将通过相应的规格达到叶子节点。达到节点 M 的所有输出值能够用 X 的子集示意。从数学上讲,让咱们用一个函数表白此状况,如果给定的输出值达到节点 M,则能够给出 1 个,否则为 0。

找到决裂数据的阈值:通过在每个步骤中抉择 2 个间断点并计算其平均值来迭代训练数据。计算的平均值将数据分为两个的阈值。

首先让咱们思考随机阈值以演示任何给定的状况。

threshold = 1.5

low = np.take(y, np.where(x < threshold))
high = np.take(y, np.where(x > threshold))

plt.scatter(x, y, s = 5, label = 'Data')
plt.plot([threshold]*2, [-16, 10], 'b--', label = 'Threshold line')
plt.plot([-2, threshold], [low.mean()]*2, 'r--', label = 'Left child prediction line')
plt.plot([threshold, 5], [high.mean()]*2, 'r--', label = 'Right child prediction line')
plt.plot([-2, 5], [y.mean()]*2, 'g--', label = 'Node prediction line')
plt.legend()

蓝色垂直线示意单个阈值,咱们假如它是任意两点的均值, 并稍后将其用于划分数据。

咱们对这个问题的第一个预测是所有训练数据 (y 轴) 的平均值(绿色水平线)。而两条红线是要创立的子节点的预测。

很显著这些平均值都不能很好地代表咱们的数据,但它们的差别也是很显著的:主节点预测 (绿线) 失去所有训练数据的均值,咱们将其分为 2 个子节点,这 2 个子节点有本人的预测(红线)。与绿线相比这 2 个子节点更好地代表了它们对应的训练数据。回归树就是将一直地将数据分成 2 个局部——从每个节点创立 2 个子节点,直到达到给定的进行值(这是一个节点所能领有的最小数据量)。它会提前进行树的构建过程,咱们将其称为预修剪树。

为什么会有早停的机制?如果咱们要持续进行调配直到节点只有一个值是,这创立一个适度拟合的计划,每个训练数据都只能预测本人。

阐明:当模型实现时,它不会应用根节点或任何两头节点来预测任何值; 它将应用回归树的叶子 (这将是树的最初一个节点) 进行预测。

为了失去最能代表给定阈值数据的阈值,咱们应用残差平方和。它能够在数学上定义为

让咱们看看这一步是如何工作的。

既然计算了阈值的 SSR 值,那么能够采纳具备最小 SSR 值的阈值。应用该阈值将训练数据分为两个(低和高局部),其中其中低局部将用于创立左子节点,高局部将用于创立右子节点。

def SSR(r, y): 
    return np.sum((r - y)**2 )
    
SSRs, thresholds = [], []
for i in range(len(x) - 1):
    threshold = x[i:i+2].mean()
    
    low = np.take(y, np.where(x < threshold))
    high = np.take(y, np.where(x > threshold))
    
    guess_low = low.mean()
    guess_high = high.mean()
    
    SSRs.append(SSR(low, guess_low) + SSR(high, guess_high))
    thresholds.append(threshold)
    
print('Minimum residual is: {:.2f}'.format(min(SSRs)))
print('Corresponding threshold value is: {:.4f}'.format(thresholds[SSRs.index(min(SSRs))]))

在进入下一步之前,我将应用 pandas 创立一个 df,并创立一个用于寻找最佳阈值的办法。所有这些步骤都能够在没有 pandas 的状况下实现,这里应用他是因为比拟不便。

df = pd.DataFrame(zip(x, y.squeeze()), columns = ['x', 'y'])

def find_threshold(df, plot = False):
    SSRs, thresholds = [], []
    for i in range(len(df) - 1):
        threshold = df.x[i:i+2].mean()

        low = df[(df.x <= threshold)]
        high = df[(df.x > threshold)]

        guess_low = low.y.mean()
        guess_high = high.y.mean()

        SSRs.append(SSR(low.y.to_numpy(), guess_low) + SSR(high.y.to_numpy(), guess_high))
        thresholds.append(threshold)
    
    if plot:
        plt.scatter(thresholds, SSRs, s = 3)
        plt.show()
        
    return thresholds[SSRs.index(min(SSRs))]

创立子节点

在将数据分成两个局部后就能够为低值和高值找到独自的阈值。须要留神的是这里要减少一个进行条件; 因为对于每个节点,属于该节点的数据集中的点会变少,所以咱们为每个节点定义了最小数据点数量。如果不这样做,每个节点将只应用一个训练值进行预测,会导致过拟合。

能够递归地创立节点,咱们定义了一个名为 TreeNode 的类,它将存储节点应该存储的每一个值。应用这个类咱们首先创立根,同时计算它的阈值和预测值。而后递归地创立它的子节点,其中每个子节点类都存储在父类的 left 或 right 属性中。

在上面的 create_nodes 办法中,首先将给定的 df 分成两局部。而后查看是否有足够的数据独自创立左右节点。如果 (对于其中任何一个) 有足够的数据点,咱们计算阈值并应用它创立一个子节点,用这个新节点作为树再次调用 create_nodes 办法。

class TreeNode():
    def __init__(self, threshold, pred):
        self.threshold = threshold
        self.pred = pred
        self.left = None
        self.right = None

def create_nodes(tree, df, stop):
    low = df[df.x <= tree.threshold]
    high = df[df.x > tree.threshold]
    
    if len(low) > stop:
        threshold = find_threshold(low)
        tree.left = TreeNode(threshold, low.y.mean())
        create_nodes(tree.left, low, stop)
        
    if len(high) > stop:
        threshold = find_threshold(high)
        tree.right = TreeNode(threshold, high.y.mean())
        create_nodes(tree.right, high, stop)
        
threshold = find_threshold(df)
tree = TreeNode(threshold, df.y.mean())

create_nodes(tree, df, 5)

这个办法在第一棵树上进行了批改,因为它不须要返回任何货色。尽管递归函数通常不是这样写的(不返回),但因为不须要返回值,所以当没有激活 if 语句时,不做任何操作。

在实现后能够查看此树结构,查看它是否创立了一些能够拟合数据的节点。这里将手动抉择第一个节点及其对根阈值的预测。

plt.scatter(x, y, s = 0.5, label = 'Data')
plt.plot([tree.threshold]*2, [-16, 10], 'r--', 
         label = 'Root threshold')
plt.plot([tree.right.threshold]*2, [-16, 10], 'g--', 
         label = 'Right node threshold')
plt.plot([tree.threshold, tree.right.threshold], 
         [tree.right.left.pred]*2,
         'g', label = 'Right node prediction')
plt.plot([tree.left.threshold]*2, [-16, 10], 'm--', 
         label = 'Left node threshold')
plt.plot([tree.left.threshold, tree.threshold], 
         [tree.left.right.pred]*2,
         'm', label = 'Left node prediction')
plt.plot([tree.left.left.threshold]*2, [-16, 10], 'k--',
         label = 'Second Left node threshold')
plt.legend()

这里看到了两个预测:

第一个左节点对高值的预测(高于其阈值)

第一个右节点对低值 (低于其阈值) 的预测

这里我手动剪切了预测线的宽度,因为如果给定的 x 值达到了这些节点中的任何一个,则将以属于该节点的所有 x 值的平均值示意,这也意味着没有其余 x 值参加 在该节点的预测中(心愿有意义)。

这种树形构造远不止两个节点那么简略,所以咱们能够通过如下调用它的子节点来查看一个特定的叶子节点。

tree.left.right.left.left

这当然意味着这里有一个向下 4 个子结点长的分支,但它能够在树的另一个分支上深刻得多。

预测

咱们能够创立一个预测办法来预测任何给定的值。

def predict(x):
    curr_node = tree
    result = None
    while True:
        if x <= curr_node.threshold:
            if curr_node.left: curr_node = curr_node.left
            else: 
                break
        elif x > curr_node.threshold:
            if curr_node.right: curr_node = curr_node.right
            else: 
                break
                
    return curr_node.pred

预测办法做的是沿着树向下,通过比拟咱们的输出和每个叶子的阈值。如果输出值大于阈值,则转到右叶,如果小于阈值,则转到左叶,以此类推,直到达到任何底部叶子节点。而后应用该节点本身的预测值进行预测,并与其阈值进行最初的比拟。

应用 x = 3 进行测试(在创立数据时,能够应用下面所写的函数计算理论值。-3**2+3+5 = -1,这是期望值),咱们失去:

predict(3)
# -1.23741

计算误差

这里用绝对平方误差验证数据

def RSE(y, g): 
    return sum(np.square(y - g)) / sum(np.square(y - 1 / len(y)*sum(y)))

x_val = np.random.uniform(-2, 5, 50)
y_val = np.array([f(i) for i in x_val] ).squeeze()

tr_preds = np.array([predict(i) for i in df.x] )
val_preds = np.array([predict(i) for i in x_val] )
print('Training error: {:.4f}'.format(RSE(df.y, tr_preds)))
print('Validation error: {:.4f}'.format(RSE(y_val, val_preds)))

能够看到误差并不大,后果如下

概括的步骤

更深刻的模型

一个更适宜回归树模型的数据:因为咱们的数据是多项式生成的数据,所以应用多项式回归模型能够更好地拟合。咱们更换一下训练数据,把新函数设为

def f(x):
    mu, sigma = 0, 0.5
    if x < 3: return 1 + np.random.normal(mu, sigma, 1)
    elif x >= 3 and x < 6: return 9 + np.random.normal(mu, sigma, 1)
    elif x >= 6: return 5 + np.random.normal(mu, sigma, 1)
    
np.random.seed(1)
    
x = np.random.uniform(0, 10, num_points)
y = np.array([f(i) for i in x] )

plt.scatter(x, y, s = 5)

在此数据集上运行了下面的所有雷同过程,后果如下

比咱们从多项式数据中取得的误差低。

最初共享一下下面动图的代码:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

#===================================================Create Data
def f(x):
    mu, sigma = 0, 1.5
    return -x**2 + x + 5 + np.random.normal(mu, sigma, 1)

np.random.seed(1)
    
x = np.random.uniform(-2, 5, 300)
y = np.array([f(i) for i in x] )

p = x.argsort()
x = x[p]
y = y[p]

#===================================================Calculate Thresholds
def SSR(r, y): #send numpy array
    return np.sum((r - y)**2 )

SSRs, thresholds = [], []
for i in range(len(x) - 1):
    threshold = x[i:i+2].mean()
    
    low = np.take(y, np.where(x < threshold))
    high = np.take(y, np.where(x > threshold))
    
    guess_low = low.mean()
    guess_high = high.mean()
    
    SSRs.append(SSR(low, guess_low) + SSR(high, guess_high))
    thresholds.append(threshold)

#===================================================Animated Plot
fig, (ax1, ax2) = plt.subplots(2,1, sharex = True)
x_data, y_data = [], []
x_data2, y_data2 = [], []
ln, = ax1.plot([], [], 'r--')
ln2, = ax2.plot(thresholds, SSRs, 'ro', markersize = 2)
line = [ln, ln2]

def init():
    ax1.scatter(x, y, s = 3)
    ax1.title.set_text('Trying Different Thresholds')
    ax2.title.set_text('Threshold vs SSR')
    ax1.set_ylabel('y values')
    ax2.set_xlabel('Threshold')
    ax2.set_ylabel('SSR')
    return line

def update(frame):
    x_data = [x[frame:frame+2].mean()] * 2
    y_data = [min(y), max(y)]
    line[0].set_data(x_data, y_data)

    x_data2.append(thresholds[frame])
    y_data2.append(SSRs[frame])
    line[1].set_data(x_data2, y_data2)
    return line

ani = FuncAnimation(fig, update, frames = 298,
                    init_func = init, blit = True)
plt.show()

https://avoid.overfit.cn/post/68d76a2540894366bb7033ff120a30d6

作者:Berat Yildirim

正文完
 0