论文精读:Deep Neural Decision Trees

Deep Neural Decision Trees

Soft binning function

Soft binning function 这个函数的功能为:输入一个标量 x x x ,生成标量 x x x 属于的区间的索引。具体如何实现的呢?往下看:

假设我们有一个连续的变量 x x x,我们想把它分隔成 n + 1 n+1 n+1 个间隔。这样就需要 n n n 个切割点(cut points),这 n n n 个切割点是可以训练的变量。将 n n n 个切割点记做 [ β 1 , β 2 , . . . , β n ] [β_1, β_2, . . . , β_n] [β1,β2,...,βn],并且 β 1 < β 2 < ⋅ ⋅ ⋅ < β n . β_1 < β_2 < ··· < β_n. β1<β2<<βn.

我们用 Softmax 作为**函数构造一个单层神经网络:
π = f w , b , τ ( x ) = s o f t m a x ( ( w x + b ) / τ ) π = f_{w,b,τ}(x) = softmax((wx + b)/τ ) π=fw,b,τ(x)=softmax((wx+b)/τ)
这里的 w w w 是常量而不是可以训练的变量。将 w w w 的值记为: $w = [1, 2, . . . , n + 1]. $ b b b 记作:
b = [ 0 , − β 1 , − β 1 − β 2 , . . . , − β 1 − β 2 − ⋅ ⋅ ⋅ − β n ] . b=[0,−β_1,−β_1 −β_2,...,−β_1 −β_2 −···−β_n]. b=[0,β1,β1β2,...,β1β2βn].
并且 $ τ > 0$ 是一个系数. 当 τ → 0 τ → 0 τ0 时输出趋向于一个 one-hot 向量。

举个栗子:假设有三个连续的 logits : o i − 1 , o i , o i + 1 o_{i−1}, o_{i}, o_{i+1} oi1,oi,oi+1 , 当同时满足 o i > o i − 1 o_{i} > o_{i−1} oi>oi1 (即 x > β i x > β_i x>βi) 和 $ o_i > o_{i+1}$ (即 $ x < β_{i+1}$), x x x 就一定落在 ( β i , β i + 1 ) (β_i , β_{i+1} ) (βi,βi+1) 范围内。

比如我们有一个范围为 [ 0 , 1 ] [0,1] [01] 的标量 x x x,两个切割为在 0.33 0.33 0.33 0.66 0.66 0.66,即 β 1 = 0.33 , β 2 = 0.66 β_1=0.33, β_2=0.66 β1=0.33,β2=0.66。那么根据上面两个公式可得到三个 logits: o 1 = x , o 2 = 2 x − 0.33 , o 3 = 3 x − 0.99 o_{1}=x, o_{2}=2x-0.33, o_{3}=3x-0.99 o1=x,o2=2x0.33,o3=3x0.99 。如果 o 2 > o 1 o_{2} > o_{1} o2>o1 那么 2 x − 0.33 > x 2x-0.33 > x 2x0.33>x x > β 1 = 0.33 x > β_1 = 0.33 x>β1=0.33, 如果 o 2 > o 3 o_{2} > o_{3} o2>o3 那么 2 x − 0.33 > 3 x − 0.99 2x-0.33 > 3x - 0.99 2x0.33>3x0.99 x < ( 0.99 − 0.33 ) = ( β 2 − 0.33 ) x < (0.99 - 0.33) = (β_{2} - 0.33) x<(0.990.33)=(β20.33)。这样的话,当满足 o 2 > o 1 o_{2} > o_{1} o2>o1 o 2 > o 3 o_{2} > o_{3} o2>o3 时, x x x 落在区间 ( β 1 , β 2 ) (β_1 , β_{2} ) (β1,β2) 内。

下图可以看到 Soft binning function 的函数曲线:

QQ20201103-211105@2x

x x x 轴是连续输入变量 x ∈ [ 0 , 1 ] x∈[0,1] x[01] 的值。左上:logits 的原始值;右上:应用 τ = 1 τ= 1 τ=1 的 Softmax 函数后的值;左下: τ = 0.1 τ= 0.1 τ=0.1 ;右下: τ = 0.01 τ= 0.01 τ=0.01

通过上图中的左下可以得知,如果 x = 0.15 x = 0.15 x=0.15 ,此时 o 1 > o 2 > o 3 o_1 > o_2 > o_3 o1>o2>o3 ,那么 x > 2 x − 0.33 x > 2x - 0.33 x>2x0.33 0.33 > x 0.33 > x 0.33>x,那么落在了第一个切割点 β 1 β_1 β1 的左面,同理有了这三个曲线,我们就能比较它们在 x x x 取不同值的时候的大小,这样就能确定它们位于哪些分隔点之间。

所以使用这个函数就能根据输入的 x x x 生成近似于 one-hot 的向量,尤其是在小 τ τ τ 时。看上图的右下角, τ τ τ 越来越小的时候函数变得非常置信。当 x x x 在区间 0.4 − 0.6 0.4 - 0.6 0.40.6 区间时, [ o 1 , o 2 , o 3 ] [o_1, o_2, o_3] [o1,o2,o3] 近似等于 [ 0 , 1 , 0 ] [0, 1, 0] [0,1,0], 同理在区间 0.0 − 0.4 0.0 - 0.4 0.00.4 时为 [ 1 , 0 , 0 ] [1, 0, 0] [1,0,0], 在区间 0.6 − 1.0 0.6 - 1.0 0.61.0 区间时为 [ 0 , 0 , 1 ] [0, 0, 1] [0,0,1]。这样的话就能和 Soft binning function 这个函数的功能对应上了:输入一个标量 x x x ,生成标量 x x x 属于的区间的索引。

Construct decision tree

有了 binning function,那么还需要用到 Kronecker product ⊗ ⊗ 操作。下图是一个 Kronecker product 的例子:

QQ20201103-231459@2x

假设我们有一个实例输入 x ∈ R D x ∈ R^D xRD D D D 个特征。对 D D D 个特征中每一个特征 x d x_d xd 都进行 binning function 操作通过自己的 neural network f d ( x d ) f_d(x_d) fd(xd),这样我们就能查找最终的节点通过 Kronecker product 操作:
z = f 1 ( x 1 ) ⊗ f 2 ( x 2 ) ⊗ ⋅ ⋅ ⋅ ⊗ f D ( x D ) . z = f_1(x_1) ⊗ f_2(x_2) ⊗ · · · ⊗ f_D(x_D). z=f1(x1)f2(x2)fD(xD).
这里的 z z z 也近似是一个 one-hot 向量来指代 x x x 到达的叶子节点的索引。最后假设每个叶子 z z z 处都有一个线性分类器用来分类到达这里的实例。

下图是在 Iris 数据集上学习到的 DNDT(只用了两个特征:Petal Length 和 Petal Width),其中红色的字体指代的是可以训练的参数,而黑色的字体是常量。下面是训练后的结果,我门使用这颗树进行预测。

2

假设我们有一个新的数据 P e t a l   L e n g t h = 3 , P e t a l   W i d t h = 2 Petal \ Length = 3, Petal \ Width = 2 Petal Length=3,Petal Width=2 输入到下面这颗学习好的神经网络决策树中。计算的流程如下:
f 1 ( 3 ) = s o f t m a x ( ( [ 1 , 2 ] ⋅ 3 + [ 0 , − 2.58 ] ) / τ ) = s o f t m a x ( ( [ 3 , 3.42 ] ) / τ ) f_{1}(3) = softmax(([1, 2]\cdot3 + [0, -2.58])/τ) \\=softmax(([3, 3.42])/τ) f1(3)=softmax(([1,2]3+[0,2.58])/τ)=softmax(([3,3.42])/τ)
τ τ τ 很小的时候, f 1 ( 3 ) f_{1}(3) f1(3) 近似于一个 one-hot 向量 [ 0 , 1 ] [0, 1] [0,1]。同理可得到 f 2 ( 2 ) ≈ [ 0 , 1 ] f_{2}(2) \approx [0, 1] f2(2)[0,1]。使用公式 z = f 1 ( 3 ) ⊗ f 2 ( 2 ) = [ 0 , 1 ] ⊗ [ 0 , 1 ] = [ 0 , 0 , 0 , 1 ] z = f_1(3) ⊗ f_2(2) = [0,1] ⊗ [0,1]=[0,0,0,1] z=f1(3)f2(2)=[0,1][0,1]=[0,0,0,1]

得到的 Kron Product 结果放入一个分类期,得到分类结果:
z ⋅ W = [ 0 , 0 , 0 , 1 ] ⋅ [ [ . . . ] , [ . . . ] , [ . . . ] , [ − 3.24 , − 2.51 , 6.56 ] ] = [ − 3.24 , − 2.51 , 6.56 ] z \cdot W=[0,0,0,1] \cdot [[...],[...],[...],[-3.24,-2.51,6.56]] \\ = [-3.24, -2.51, 6.56] zW=[0,0,0,1][[...],[...],[...],[3.24,2.51,6.56]]=[3.24,2.51,6.56]
对于向量 [ − 3.24 , − 2.51 , 6.56 ] [-3.24, -2.51, 6.56] [3.24,2.51,6.56] 索引位置 3 上的值是最大的,此时可以判断这个新数据是第三分类,即 Virginica。

下图是普通决策树构建的过程。

3

Learning the Tree

现在我们知道了如何找到输入实例的路径,并且分类它。那么训练的时候就需要训练 cut points 和 leaf classifiers。但是由于神经网络 mini-batch 风格的训练,DNDT 可以很好地扩展实例的数量。但是,到目前为止,该设计的一个关键缺点是:由于使用了Kronecker Product,因此就 feature 数量而言无法扩展。在我们目前的实现中,我们通过训练具有随机子空间的森林来避免“宽”数据集的问题 - 但这会以可解释性为代价。

也就是说训练多棵树,每棵树的训练基于所有特征的子集合。子集合的选取是随机的,这样就能通过多棵树把所有特征都考虑了,这样就能变向的解决 “宽” 数据集的问题。

更好的解决方案(可以不借助不可解释的森林的方案)是在训练过程中探索最后 binning function 结果的的稀疏性:非空叶的数量增长比叶总数慢得多。

Experiments

代码:DNDT

下面是实验基于的数据集:

1

对于 BaseLine 模型决策树(DT),我们将两个关键超参数设置为“ gini”,将分割器设置为“ best”。 对于神经网络(NN),我们对所有数据集使用两个包含 50 个神经元的隐藏层的体系结构。 DNDT 还具有一个超参数,即每个要素的切点数量(分支因子),对于数据集,我们将其设置为1。

对于具有 12 个以上特征的数据集,我们使用 DNDT 的 ensemble 版本,其中每棵树随机选择 10 个特征,总共有 10 棵树。 最终的预测是由多数投票给出的。

Results

下面是在这些数据集上三种模型的表现。总体而言,性能最好的模型是 DT。 DT 的良好性能不足为奇,因为这些数据集主要是表格形式的,并且特征维相对较低

传统上,神经网络在此类数据上没有明显的优势。 但是,DNDT 略优于普通神经网络,因为它在设计上更接近决策树。 当然,这只是一个指示性的结果,因为所有这些模型都具有可调整的超参数。 然而,有趣的是,没有任何一种模型具有主导优势。 这让人想起没有免费的午餐定理。

啥是没有免费的午餐定理咱也不知道,可能大佬写文章就喜欢整这些文绉绉的东西吧。

2