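About the author
Han Mingcong is a TiDB Contributor and a PhD student at the IPADS lab of Shanghai Jiao Tong University, where his research focuses on systems software. This article shows how to train a machine learning model in TiDB using pure SQL.
Preface
As is well known, TiDB 5.1 added many new features, one of which is the Common Table Expression (CTE) from the ANSI SQL 99 standard. In general, a CTE acts as a temporary view scoped to a single statement; it lets you decouple a complex SQL query and improves development efficiency. But CTEs have another important form, the recursive CTE, which allows a CTE to reference itself. This is the last core piece of the puzzle in completing SQL's capabilities. On StackOverflow there is a discussion titled "Is SQL or even TSQL Turing Complete", and the most upvoted answer contains this passage: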
“In this set of slides Andrew Gierth proves that with CTE and Windowing SQL is Turing Complete, by constructing a cyclic tag system , which has been proved to be Turing Complete. The CTE feature is the important part however – it allows you to create named sub-expressions that can refer to themselves, and thereby recursively solve problems.”
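In other words, CTEs and window functions even make SQL a Turing-complete language. That reminded me of an article I read years ago, Deep Neural Network implemented in pure SQL over BigQuery, whose author used pure SQL to implement a DNN model. But when I opened the repo, the title turned out to be clickbait: he actually still used Python to drive the iterative training. So, since recursive CTEs give us the ability to "iterate", I wanted to take up the challenge: can the training and inference of a machine learning model be implemented in TiDB using pure SQL?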
Iris Dataset
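First we need to pick a simple machine learning model and task. Let's try the iris dataset, the introductory dataset in sklearn. It contains 150 records in 3 classes, 50 records per class. Each record has 4 features: sepal length, sepal width, petal length, and petal width. From these 4 features we can predict which of the species iris-setosa, iris-versicolour, or iris-virginica a flower belongs to.
Once the data is downloaded (it is already in CSV format), we import it into TiDB.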
mysql> create table iris(sl float, sw float, pl float, pw float, type varchar(16));
mysql> LOAD DATA LOCAL INFILE 'iris.csv' INTO TABLE iris FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' ;
mysql> select * from iris limit 10;
+------+------+------+------+-------------+
| sl   | sw   | pl   | pw   | type        |
+------+------+------+------+-------------+
|  5.1 |  3.5 |  1.4 |  0.2 | Iris-setosa |
|  4.9 |    3 |  1.4 |  0.2 | Iris-setosa |
|  4.7 |  3.2 |  1.3 |  0.2 | Iris-setosa |
|  4.6 |  3.1 |  1.5 |  0.2 | Iris-setosa |
|    5 |  3.6 |  1.4 |  0.2 | Iris-setosa |
|  5.4 |  3.9 |  1.7 |  0.4 | Iris-setosa |
|  4.6 |  3.4 |  1.4 |  0.3 | Iris-setosa |
|    5 |  3.4 |  1.5 |  0.2 | Iris-setosa |
|  4.4 |  2.9 |  1.4 |  0.2 | Iris-setosa |
|  4.9 |  3.1 |  1.5 |  0.1 | Iris-setosa |
+------+------+------+------+-------------+
10 rows in set (0.00 sec)
mysql> select type, count(*) from iris group by type;
+-----------------+----------+
| type            | count(*) |
+-----------------+----------+
| Iris-versicolor |       50 |
| Iris-setosa     |       50 |
| Iris-virginica  |       50 |
+-----------------+----------+
3 rows in set (0.00 sec)
Softmax Logistic Regression
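Here we pick a simple machine learning model, softmax logistic regression, to do multi-class classification. (The formulas and description below follow Baidu Baike.) In the notation used here, w_j is the weight vector of class j, k is the number of classes, m is the number of samples, and \alpha is the learning rate.
In softmax regression, the probability of classifying x as class j is:
$$P(y = j \mid x; w) = \frac{e^{w_j^{T} x}}{\sum_{l=1}^{k} e^{w_l^{T} x}}$$
The cost function is:
$$J(w) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{k} 1\{y^{(i)} = j\} \log \frac{e^{w_j^{T} x^{(i)}}}{\sum_{l=1}^{k} e^{w_l^{T} x^{(i)}}}$$
From this we can derive the gradient:
$$\nabla_{w_j} J(w) = -\frac{1}{m} \sum_{i=1}^{m} x^{(i)} \left( 1\{y^{(i)} = j\} - P(y^{(i)} = j \mid x^{(i)}; w) \right)$$
So the parameters can be learned by gradient descent, updating them at each step:
$$w_j := w_j - \alpha \nabla_{w_j} J(w)$$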
Model Inference
Let's first write SQL for inference. Given the model and data defined above, the input X has five dimensions (sl, sw, pl, pw, and a constant 1.0), and the output uses one-hot encoding.
mysql> create table data(x0 decimal(35, 30), x1 decimal(35, 30), x2 decimal(35, 30), x3 decimal(35, 30), x4 decimal(35, 30), y0 decimal(35, 30), y1 decimal(35, 30), y2 decimal(35, 30));
mysql> insert into data
select sl, sw, pl, pw, 1.0,
       case when type = 'Iris-setosa' then 1 else 0 end,
       case when type = 'Iris-versicolor' then 1 else 0 end,
       case when type = 'Iris-virginica' then 1 else 0 end
from iris;
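The encoding performed by the statement above can be written out for a single row. This is only an illustrative Python sketch; encode_row and CLASSES are names made up here, not part of the article's SQL.

```python
# Hypothetical helper mirroring the INSERT ... SELECT above for one row.
CLASSES = ("Iris-setosa", "Iris-versicolor", "Iris-virginica")

def encode_row(sl, sw, pl, pw, species):
    """Return (x0..x4, y0..y2): four features, a constant 1.0, and a one-hot label."""
    features = (sl, sw, pl, pw, 1.0)
    one_hot = tuple(1 if species == name else 0 for name in CLASSES)
    return features + one_hot

print(encode_row(5.1, 3.5, 1.4, 0.2, "Iris-setosa"))
# (5.1, 3.5, 1.4, 0.2, 1.0, 1, 0, 0)
```

The constant 1.0 in x4 plays the role of a bias term, so each class needs only one weight vector of length 5.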
There are 3 classes * 5 dimensions = 15 parameters in total:
mysql> create table weight(w00 decimal(35, 30), w01 decimal(35, 30), w02 decimal(35, 30), w03 decimal(35, 30), w04 decimal(35, 30), w10 decimal(35, 30), w11 decimal(35, 30), w12 decimal(35, 30), w13 decimal(35, 30), w14 decimal(35, 30), w20 decimal(35, 30), w21 decimal(35, 30), w22 decimal(35, 30), w23 decimal(35, 30), w24 decimal(35, 30));
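We first initialize them to 0.1, 0.2, and 0.3, one value per class (different numbers are used to make the demonstration easier to follow; initializing everything to 0.1 also works):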
mysql> insert into weight values (0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3);
Next we write a SQL statement that computes the accuracy of inference over all of the data.
To make this easier to follow, here is pseudocode for the process:
weight = (w00, w01, w02, w03, w04, w10, w11, w12, w13, w14, w20, w21, w22, w23, w24)
for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
    exp0 = exp(x0 * w00 + x1 * w01 + x2 * w02 + x3 * w03 + x4 * w04)
    exp1 = exp(x0 * w10 + x1 * w11 + x2 * w12 + x3 * w13 + x4 * w14)
    exp2 = exp(x0 * w20 + x1 * w21 + x2 * w22 + x3 * w23 + x4 * w24)
    sum_exp = exp0 + exp1 + exp2
    // softmax
    p0 = exp0 / sum_exp
    p1 = exp1 / sum_exp
    p2 = exp2 / sum_exp
    // inference result
    r0 = p0 > p1 and p0 > p2
    r1 = p1 > p0 and p1 > p2
    r2 = p2 > p0 and p2 > p1
    data.correct = (y0 == r0 and y1 == r1 and y2 == r2)
return sum(Data.correct) / count(Data)
In the code above, we process each row of Data: first we take the exp of the three vector dot products, then compute the softmax, and finally set the largest of p0, p1, p2 to 1 and the others to 0. That completes inference for one sample. If the inferred class matches the sample's actual class, the prediction is correct. Summing the correct predictions over all samples and dividing by the sample count gives the accuracy.
The SQL implementation follows. We join each row of data with weight (which has only one row), compute the inference result for each row, and then sum the number of correct samples:
select sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*)
from (select y0, y1, y2,
             p0 > p1 and p0 > p2 as r0,
             p1 > p0 and p1 > p2 as r1,
             p2 > p0 and p2 > p1 as r2
      from (select y0, y1, y2,
                   e0/(e0+e1+e2) as p0,
                   e1/(e0+e1+e2) as p1,
                   e2/(e0+e1+e2) as p2
            from (select y0, y1, y2,
                         exp(w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4) as e0,
                         exp(w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4) as e1,
                         exp(w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4) as e2
                  from data, weight) t1
           ) t2
     ) t3;
As you can see, the SQL follows the pseudocode almost step by step. The result:
+-----------------------------------------------+
| sum(y0 = r0 and y1 = r1 and y2 = r2)/count(*) |
+-----------------------------------------------+
|                                        0.3333 |
+-----------------------------------------------+
1 row in set (0.01 sec)
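The 0.3333 has a simple explanation, which a few lines of Python can confirm. This is a cross-check sketch, not part of the SQL pipeline, and the function names are made up here. Every feature is positive and class 2 has the largest initial weights, so every sample is predicted as class 2 (Iris-virginica), which is correct for exactly 50 of the 150 rows. The sample rows are the first three from the iris table shown earlier, with the constant 1.0 appended.

```python
import math

# Initial weights from the article: all five weights of class j are 0.1 * (j + 1).
WEIGHTS = [[0.1] * 5, [0.2] * 5, [0.3] * 5]

# First three rows of the iris table shown earlier, plus the constant 1.0.
SAMPLES = [
    (5.1, 3.5, 1.4, 0.2, 1.0),
    (4.9, 3.0, 1.4, 0.2, 1.0),
    (4.7, 3.2, 1.3, 0.2, 1.0),
]

def softmax_probs(x, weights):
    """Return the softmax probability of each class for one sample."""
    exps = [math.exp(sum(w * v for w, v in zip(row, x))) for row in weights]
    total = sum(exps)
    return [e / total for e in exps]

def predict(x, weights):
    """Index of the most probable class."""
    probs = softmax_probs(x, weights)
    return probs.index(max(probs))

for x in SAMPLES:
    print(predict(x, WEIGHTS))  # class 2 for every row
```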
Next we learn the model's parameters.
Model Training
Note: to keep things simple, we ignore training/validation splits here and train on all of the data.
Again we start with pseudocode and then write the SQL from it:
weight = (w00, w01, w02, w03, w04, w10, w11, w12, w13, w14, w20, w21, w22, w23, w24)
for iter in iterations:
    sum00 = 0
    sum01 = 0
    ...
    sum23 = 0
    sum24 = 0
    for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
        exp0 = exp(x0 * w00 + x1 * w01 + x2 * w02 + x3 * w03 + x4 * w04)
        exp1 = exp(x0 * w10 + x1 * w11 + x2 * w12 + x3 * w13 + x4 * w14)
        exp2 = exp(x0 * w20 + x1 * w21 + x2 * w22 + x3 * w23 + x4 * w24)
        sum_exp = exp0 + exp1 + exp2
        // label minus softmax probability
        p0 = y0 - exp0 / sum_exp
        p1 = y1 - exp1 / sum_exp
        p2 = y2 - exp2 / sum_exp
        sum00 += p0 * x0
        sum01 += p0 * x1
        sum02 += p0 * x2
        ...
        sum23 += p2 * x3
        sum24 += p2 * x4
    w00 = w00 + learning_rate * sum00 / Data.size
    w01 = w01 + learning_rate * sum01 / Data.size
    ...
    w23 = w23 + learning_rate * sum23 / Data.size
    w24 = w24 + learning_rate * sum24 / Data.size
It looks rather tedious because we chose to unroll vectors such as sum and w by hand.
Now let's write the training SQL, starting with the SQL for a single iteration.
Set the learning rate and the number of samples:
mysql> set @lr = 0.1;
Query OK, 0 rows affected (0.00 sec)
mysql> set @dsize = 150;
Query OK, 0 rows affected (0.00 sec)
Run one iteration:
select w00 + @lr * sum(d00) / @dsize as w00, w01 + @lr * sum(d01) / @dsize as w01, w02 + @lr * sum(d02) / @dsize as w02, w03 + @lr * sum(d03) / @dsize as w03, w04 + @lr * sum(d04) / @dsize as w04,
       w10 + @lr * sum(d10) / @dsize as w10, w11 + @lr * sum(d11) / @dsize as w11, w12 + @lr * sum(d12) / @dsize as w12, w13 + @lr * sum(d13) / @dsize as w13, w14 + @lr * sum(d14) / @dsize as w14,
       w20 + @lr * sum(d20) / @dsize as w20, w21 + @lr * sum(d21) / @dsize as w21, w22 + @lr * sum(d22) / @dsize as w22, w23 + @lr * sum(d23) / @dsize as w23, w24 + @lr * sum(d24) / @dsize as w24
from (select w00, w01, w02, w03, w04, w10, w11, w12, w13, w14, w20, w21, w22, w23, w24,
             p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
             p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
             p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
      from (select w00, w01, w02, w03, w04, w10, w11, w12, w13, w14, w20, w21, w22, w23, w24,
                   x0, x1, x2, x3, x4,
                   y0 - e0/(e0+e1+e2) as p0,
                   y1 - e1/(e0+e1+e2) as p1,
                   y2 - e2/(e0+e1+e2) as p2
            from (select w00, w01, w02, w03, w04, w10, w11, w12, w13, w14, w20, w21, w22, w23, w24,
                         x0, x1, x2, x3, x4, y0, y1, y2,
                         exp(w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4) as e0,
                         exp(w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4) as e1,
                         exp(w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4) as e2
                  from data, weight) t1
           ) t2
     ) t3;
The result is the model parameters after one iteration:
+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
| w00                              | w01                              | w02                              | w03                              | w04                              | w10                              | w11                              | w12                              | w13                              | w14                              | w20                              | w21                              | w22                              | w23                              | w24                              |
+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
| 0.242000022455130986666666666667 | 0.199736070114635900000000000000 | 0.135689102774125773333333333333 | 0.104372938417325687333333333333 | 0.128775320011717430666666666667 | 0.296128284590438133333333333333 | 0.237124925707748246666666666667 | 0.281477497498236260000000000000 | 0.225631554555397960000000000000 | 0.215390025342499213333333333333 | 0.061871692954430866666666666667 | 0.163139004177615846666666666667 | 0.182833399727637980000000000000 | 0.269995507027276353333333333333 | 0.255834654645783353333333333333 |
+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
1 row in set (0.03 sec)
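One way to sanity-check a single update is an invariant of this model: for every sample the three residuals (y_j minus the softmax probability) sum to zero, so for each dimension k the sum w0k + w1k + w2k is unchanged by an update. The SQL result above satisfies this: w00 + w10 + w20 is still about 0.6, the sum of the initial 0.1, 0.2, and 0.3. The Python sketch below mirrors one update and checks the invariant; gradient_step and the three sample rows are illustrative assumptions, not the full 150-row table.

```python
import math

def gradient_step(weights, rows, lr):
    """One full-batch update: w[j][k] += lr * mean((y_j - p_j) * x_k)."""
    n_classes, n_dims = len(weights), len(weights[0])
    sums = [[0.0] * n_dims for _ in range(n_classes)]
    for x, y in rows:
        exps = [math.exp(sum(w * v for w, v in zip(wj, x))) for wj in weights]
        total = sum(exps)
        for j in range(n_classes):
            residual = y[j] - exps[j] / total      # p_j in the pseudocode
            for k in range(n_dims):
                sums[j][k] += residual * x[k]      # d_jk in the SQL
    return [[weights[j][k] + lr * sums[j][k] / len(rows) for k in range(n_dims)]
            for j in range(n_classes)]

# Illustrative rows only: (x0..x4, one-hot y).
ROWS = [
    ((5.1, 3.5, 1.4, 0.2, 1.0), (1, 0, 0)),
    ((7.0, 3.2, 4.7, 1.4, 1.0), (0, 1, 0)),
    ((6.3, 3.3, 6.0, 2.5, 1.0), (0, 0, 1)),
]
INITIAL = [[0.1] * 5, [0.2] * 5, [0.3] * 5]

updated = gradient_step(INITIAL, ROWS, lr=0.1)
column_sums = [sum(updated[j][k] for j in range(3)) for k in range(5)]
print(column_sums)  # each entry stays at about 0.6
```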
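Now comes the core part: we use a recursive CTE to run the iterative training:
mysql> set @num_iterations = 1000;
Query OK, 0 rows affected (0.00 sec)
The core idea is that the input of each iteration is the result of the previous one, and we add an incrementing iteration variable to control the number of iterations. The overall skeleton:
with recursive cte(iter, weight) as
(
    select 1, init_weight
    union all
    select iter + 1, new_weight
    from cte
    where iter < @num_iterations
)
Next, we combine the single-iteration SQL with this iteration framework (to improve numerical precision, some type casts are added to intermediate results):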
with recursive weight(iter, w00, w01, w02, w03, w04, w10, w11, w12, w13, w14, w20, w21, w22, w23, w24) as
(
    select 1,
        cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
        cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
        cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30))
    union all
    select iter + 1,
        w00 + @lr * cast(sum(d00) as DECIMAL(35, 30)) / @dsize as w00,
        w01 + @lr * cast(sum(d01) as DECIMAL(35, 30)) / @dsize as w01,
        w02 + @lr * cast(sum(d02) as DECIMAL(35, 30)) / @dsize as w02,
        w03 + @lr * cast(sum(d03) as DECIMAL(35, 30)) / @dsize as w03,
        w04 + @lr * cast(sum(d04) as DECIMAL(35, 30)) / @dsize as w04,
        w10 + @lr * cast(sum(d10) as DECIMAL(35, 30)) / @dsize as w10,
        w11 + @lr * cast(sum(d11) as DECIMAL(35, 30)) / @dsize as w11,
        w12 + @lr * cast(sum(d12) as DECIMAL(35, 30)) / @dsize as w12,
        w13 + @lr * cast(sum(d13) as DECIMAL(35, 30)) / @dsize as w13,
        w14 + @lr * cast(sum(d14) as DECIMAL(35, 30)) / @dsize as w14,
        w20 + @lr * cast(sum(d20) as DECIMAL(35, 30)) / @dsize as w20,
        w21 + @lr * cast(sum(d21) as DECIMAL(35, 30)) / @dsize as w21,
        w22 + @lr * cast(sum(d22) as DECIMAL(35, 30)) / @dsize as w22,
        w23 + @lr * cast(sum(d23) as DECIMAL(35, 30)) / @dsize as w23,
        w24 + @lr * cast(sum(d24) as DECIMAL(35, 30)) / @dsize as w24
    from (select iter, w00, w01, w02, w03, w04, w10, w11, w12, w13, w14, w20, w21, w22, w23, w24,
                 p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
                 p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
                 p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
          from (select iter, w00, w01, w02, w03, w04, w10, w11, w12, w13, w14, w20, w21, w22, w23, w24,
                       x0, x1, x2, x3, x4,
                       y0 - e0/(e0+e1+e2) as p0,
                       y1 - e1/(e0+e1+e2) as p1,
                       y2 - e2/(e0+e1+e2) as p2
                from (select iter, w00, w01, w02, w03, w04, w10, w11, w12, w13, w14, w20, w21, w22, w23, w24,
                             x0, x1, x2, x3, x4, y0, y1, y2,
                             exp(w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4) as e0,
                             exp(w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4) as e1,
                             exp(w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4) as e2
                      from data, weight where iter < @num_iterations) t1
               ) t2
         ) t3
    having count(*) > 0
)
select * from weight where iter = @num_iterations;
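This version differs from the single-iteration version above in two ways:
- After data join weight, we add where iter < @num_iterations to control the number of iterations, and we add a column iter + 1 as iter to the output;
- We also add having count(*) > 0 at the end. Without it, the aggregation would still output a row when there is no input left, and the iteration would never terminate.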
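The role of these two additions can be mimicked in a few lines of Python. This is only a toy model of how a recursive CTE is evaluated (feed the rows of the previous round back in until a round returns nothing), not TiDB's executor. run_recursive_cte, make_step, and the halving "update" are invented for illustration.

```python
def run_recursive_cte(seed, step, max_rounds=10_000):
    """Start from the seed rows and keep applying `step` to the rows of the
    previous round until a round produces no rows."""
    result = list(seed)
    working = list(seed)
    rounds = 0
    while working:
        if rounds >= max_rounds:
            raise RuntimeError("recursion did not terminate")
        working = step(working)
        result.extend(working)
        rounds += 1
    return result

def make_step(num_iterations, having_count):
    def step(rows):
        # WHERE iter < @num_iterations (a NULL iter never passes the filter)
        inputs = [(it, w) for it, w in rows if it is not None and it < num_iterations]
        if not inputs:
            # An aggregate without GROUP BY still returns one row over empty
            # input, modelled here as (None, None); HAVING count(*) > 0 drops it.
            return [] if having_count else [(None, None)]
        it, w = inputs[0]
        return [(it + 1, w * 0.5)]  # stand-in for the real weight update
    return step

rows = run_recursive_cte([(1, 1.0)], make_step(4, having_count=True))
print(rows)  # [(1, 1.0), (2, 0.5), (3, 0.25), (4, 0.125)]
```

With having_count=False the same call keeps producing the (None, None) row every round and only stops when max_rounds raises, which is the non-terminating behaviour the HAVING clause is there to prevent.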
Then we get this result:
ERROR 3577 (HY000): In recursive query block of Recursive Common Table Expression 'weight', the recursive table must be referenced only once, and not in any subquery
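Uh... a recursive CTE does not allow subqueries in its recursive part! Still, merging all of the subqueries above into one query is not impossible, so I merged them by hand and tried again: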
ERROR 3575 (HY000): Recursive Common Table Expression 'cte' can contain neither aggregation nor window functions in recursive query block
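I can rewrite the SQL by hand to get rid of the subqueries, but if aggregate functions are not allowed, there is really nothing I can do!
At this point we can only declare the challenge failed... wait, why can't I just change TiDB's implementation?
According to the proposal, the recursive CTE implementation does not depart from TiDB's basic execution framework. After consulting @wjhuang2016, I learned that there are probably two reasons why subqueries and aggregate functions are not allowed:
- MySQL does not allow them either
- Allowing them would mean handling many corner cases, which is very complex
But here we only need to experiment with the functionality, so temporarily removing this check is fine. The diff removes the checks for subqueries and aggregate functions.
Now let's run it again: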
+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+
| iter | w00 | w01 | w02 | w03 | w04 | w10 | w11 | w12 | w13 | w14 | w20 | w21 | w22 | w23 | w24 |
+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+
| 1000 | 0.988746701341992382020000000002 | 2.154387045383744124308666666676 | -2.717791657467537500866666666671 | -1.219905459264249309799999999999 | 0.523764101056271250025665250523 | 0.822804724410132626693333333336 | -0.100577045244777709968533333327 | -0.033359805866941626546666666669 | -1.046591158370568595420000000005 | 0.757865074561280001352887284083 | -1.511551425752124944953333333333 | -1.753810000138966371560000000008 | 3.051151463334479351666666666650 | 2.566496617634817948266666666655 | -0.981629175617551201349829226980 |
+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+
Success! We now have the parameters at iteration 1000!
Next we recompute the accuracy with the new parameters:
+-------------------------------------------------+
| sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*) |
+-------------------------------------------------+
| 0.9867 |
+-------------------------------------------------+
1 row in set (0.02 sec)
This time the accuracy reaches about 98.7% (148 of the 150 samples).
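For comparison, the whole train-then-evaluate loop fits in a short Python sketch. It is an assumption-laden cross-check rather than a reproduction of the result: the six rows are illustrative rows in the iris format instead of the full table, the function names are invented, and the softmax subtracts the largest logit for numerical stability, which the SQL does not do. Because the CTE row with iter = 1 holds the initial weights, the row with iter = 1000 corresponds to 999 updates, and train follows the same convention.

```python
import math

def probs(weights, x):
    logits = [sum(w * v for w, v in zip(wj, x)) for wj in weights]
    top = max(logits)                      # stability shift; not in the SQL
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def train(rows, weights, lr, num_iterations):
    """Apply num_iterations - 1 full-batch updates, like the CTE seeded at iter = 1."""
    for _ in range(num_iterations - 1):
        sums = [[0.0] * len(weights[0]) for _ in weights]
        for x, y in rows:
            p = probs(weights, x)
            for j in range(len(weights)):
                for k in range(len(x)):
                    sums[j][k] += (y[j] - p[j]) * x[k]
        weights = [[w + lr * s / len(rows) for w, s in zip(wj, sj)]
                   for wj, sj in zip(weights, sums)]
    return weights

def accuracy(rows, weights):
    """Fraction of rows whose most probable class matches the one-hot label."""
    hits = 0
    for x, y in rows:
        p = probs(weights, x)
        hits += y[p.index(max(p))] == 1
    return hits / len(rows)

# Illustrative rows only: (x0..x4, one-hot y).
ROWS = [
    ((5.1, 3.5, 1.4, 0.2, 1.0), (1, 0, 0)),
    ((4.9, 3.0, 1.4, 0.2, 1.0), (1, 0, 0)),
    ((7.0, 3.2, 4.7, 1.4, 1.0), (0, 1, 0)),
    ((6.4, 3.2, 4.5, 1.5, 1.0), (0, 1, 0)),
    ((6.3, 3.3, 6.0, 2.5, 1.0), (0, 0, 1)),
    ((5.8, 2.7, 5.1, 1.9, 1.0), (0, 0, 1)),
]

trained = train(ROWS, [[0.1] * 5 for _ in range(3)], lr=0.1, num_iterations=1000)
print(accuracy(ROWS, trained))
```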
Conclusion
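This time we successfully trained a softmax logistic regression model in TiDB using pure SQL, relying mainly on the recursive CTE feature of TiDB v5.1. During testing we found that TiDB's recursive CTE currently does not allow subqueries or aggregate functions. We made a simple change to TiDB's code to bypass this restriction, and in the end successfully trained a model that reaches about 98.7% accuracy on the iris dataset.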
Discussion
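- After some testing, I found that neither PostgreSQL nor MySQL supports aggregate functions in a recursive CTE. Implementing this may indeed involve corner cases that are hard to handle; feel free to discuss the details.
- In this attempt, all dimensions were unrolled by hand. I actually also wrote an implementation that does not need to unroll every dimension (for example, with the data table's schema being (idx, dim, value)). However, it needs to join the weight table twice, meaning the CTE has to be accessed recursively twice, and that would additionally require changing the implementation of the TiDB Executor, so I did not include it here. In practice, though, that implementation is more general: a single SQL statement can handle models with any number of dimensions (I originally wanted to try training MNIST with TiDB).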
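The long-format implementation is not shown in the article, so the following Python sketch is only a guess at its shape. It assumes data rows of the form (idx, dim, value) and weight rows of the form (cls, dim, w), and computes the per-sample, per-class logits as a join on dim followed by a group-by on (idx, cls). The names logits_long, DATA, and WEIGHT are hypothetical.

```python
from collections import defaultdict

def logits_long(data, weight):
    """Roughly: SELECT idx, cls, SUM(value * w)
                FROM data JOIN weight USING (dim) GROUP BY idx, cls."""
    by_dim = defaultdict(list)
    for cls, dim, w in weight:
        by_dim[dim].append((cls, w))
    out = defaultdict(float)
    for idx, dim, value in data:
        for cls, w in by_dim[dim]:
            out[(idx, cls)] += value * w
    return dict(out)

# One sample (idx 0) in long format: four features plus the constant 1.0.
DATA = [(0, d, v) for d, v in enumerate((5.1, 3.5, 1.4, 0.2, 1.0))]
# Same initial weights as in the article: 0.1, 0.2, 0.3 per class.
WEIGHT = [(c, d, 0.1 * (c + 1)) for c in range(3) for d in range(5)]

print(logits_long(DATA, WEIGHT))
```

Nothing in this function depends on the number of dimensions or classes, which is what makes the long format general. A training step would presumably need the weights once to compute these logits and again when adding the gradient to them, which would explain the second reference to the weight table.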